From 1d948e149c8111a2fe185fd56daa269f184b336e Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Thu, 12 Sep 2024 03:34:35 -0300 Subject: [PATCH 1/7] Preparing for Draft Pull Request --- Cargo.lock | 412 +++++++++++++++++- Cargo.toml | 13 + atrium-xrpc-wss-client/.gitignore | 2 + atrium-xrpc-wss-client/CHANGELOG.md | 5 + atrium-xrpc-wss-client/Cargo.toml | 27 ++ atrium-xrpc-wss-client/README.md | 1 + atrium-xrpc-wss-client/src/client.rs | 87 ++++ atrium-xrpc-wss-client/src/lib.rs | 6 + .../src/subscriptions/mod.rs | 1 + .../subscriptions/repositories/firehose.rs | 225 ++++++++++ .../src/subscriptions/repositories/mod.rs | 107 +++++ .../subscriptions/repositories/type_defs.rs | 54 +++ atrium-xrpc-wss/.gitignore | 2 + atrium-xrpc-wss/CHANGELOG.md | 5 + atrium-xrpc-wss/Cargo.toml | 21 + atrium-xrpc-wss/README.md | 1 + atrium-xrpc-wss/src/client/mod.rs | 27 ++ atrium-xrpc-wss/src/client/xprc_uri.rs | 16 + atrium-xrpc-wss/src/lib.rs | 4 + .../src/subscriptions/frames/mod.rs | 86 ++++ .../src/subscriptions/frames/tests.rs | 61 +++ .../src/subscriptions/handlers/mod.rs | 3 + .../subscriptions/handlers/repositories.rs | 125 ++++++ atrium-xrpc-wss/src/subscriptions/mod.rs | 81 ++++ examples/firehose/Cargo.toml | 11 +- examples/firehose/src/lib.rs | 2 - examples/firehose/src/main.rs | 207 +++++---- examples/firehose/src/stream.rs | 1 - examples/firehose/src/stream/frames.rs | 158 ------- examples/firehose/src/subscription.rs | 13 - 30 files changed, 1505 insertions(+), 259 deletions(-) create mode 100644 atrium-xrpc-wss-client/.gitignore create mode 100644 atrium-xrpc-wss-client/CHANGELOG.md create mode 100644 atrium-xrpc-wss-client/Cargo.toml create mode 100644 atrium-xrpc-wss-client/README.md create mode 100644 atrium-xrpc-wss-client/src/client.rs create mode 100644 atrium-xrpc-wss-client/src/lib.rs create mode 100644 atrium-xrpc-wss-client/src/subscriptions/mod.rs create mode 100644 atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs create mode 100644 atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs create mode 100644 atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs create mode 100644 atrium-xrpc-wss/.gitignore create mode 100644 atrium-xrpc-wss/CHANGELOG.md create mode 100644 atrium-xrpc-wss/Cargo.toml create mode 100644 atrium-xrpc-wss/README.md create mode 100644 atrium-xrpc-wss/src/client/mod.rs create mode 100644 atrium-xrpc-wss/src/client/xprc_uri.rs create mode 100644 atrium-xrpc-wss/src/lib.rs create mode 100644 atrium-xrpc-wss/src/subscriptions/frames/mod.rs create mode 100644 atrium-xrpc-wss/src/subscriptions/frames/tests.rs create mode 100644 atrium-xrpc-wss/src/subscriptions/handlers/mod.rs create mode 100644 atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs create mode 100644 atrium-xrpc-wss/src/subscriptions/mod.rs delete mode 100644 examples/firehose/src/lib.rs delete mode 100644 examples/firehose/src/stream.rs delete mode 100644 examples/firehose/src/stream/frames.rs delete mode 100644 examples/firehose/src/subscription.rs diff --git a/Cargo.lock b/Cargo.lock index ea0fa279..b9d24778 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,6 +96,18 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "arrayref" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -117,6 +129,28 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "async-trait" version = "0.1.81" @@ -197,6 +231,38 @@ dependencies = [ "wasm-bindgen-test", ] +[[package]] +name = "atrium-xrpc-wss" +version = "0.1.0" +dependencies = [ + "atrium-api", + "cbor4ii", + "futures", + "ipld-core", + "serde", + "serde_ipld_dagcbor", + "thiserror", +] + +[[package]] +name = "atrium-xrpc-wss-client" +version = "0.1.0" +dependencies = [ + "async-stream", + "atrium-xrpc", + "atrium-xrpc-wss", + "bon", + "futures", + "ipld-core", + "rs-car", + "serde", + "serde_html_form", + "serde_ipld_dagcbor", + "thiserror", + "tokio", + "tokio-tungstenite", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -254,6 +320,17 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -263,6 +340,29 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4f37d875011af3196e4828024742a84dcff6b0d027d272f2944f9a99f2c8af" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99b4b686e7ebf76cfa591052482d8c3c8242722518560798631974bf899d5565" +dependencies = [ + "darling", + "ident_case", + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "bsky-cli" version = "0.1.22" @@ -294,7 +394,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "toml", + "toml 0.8.15", "unicode-segmentation", ] @@ -304,6 +404,12 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.6.1" @@ -352,6 +458,19 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "cid" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "fd94671561e36e4e7de75f753f577edafb0e7c05d6e4547229fdf7938fbcd2c3" +dependencies = [ + "core2", + "multibase", + "multihash 0.18.1", + "serde", + "unsigned-varint 0.7.2", +] + [[package]] name = "cid" version = "0.11.1" @@ -360,7 +479,7 @@ checksum = "3147d8272e8fa0ccd29ce51194dd98f79ddfb8191ba9e3409884e751798acf3a" dependencies = [ "core2", "multibase", - "multihash", + "multihash 0.19.1", "serde", "serde_bytes", "unsigned-varint 0.8.0", @@ -385,7 +504,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -447,6 +566,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation" version = "0.9.4" @@ -540,6 +665,41 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.71", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.71", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -736,6 +896,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -758,6 +919,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.30" @@ -779,6 +951,17 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -797,11 +980,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1082,6 +1270,12 @@ 
dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1117,7 +1311,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ede82a79e134f179f4b29b5fdb1eb92bd1b38c4dfea394c539051150a21b9b" dependencies = [ - "cid", + "cid 0.11.1", "serde", "serde_bytes", ] @@ -1209,6 +1403,55 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libipld" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1ccd6b8ffb3afee7081fcaec00e1b099fd1c7ccf35ba5729d88538fcc3b4599" +dependencies = [ + "fnv", + "libipld-cbor", + "libipld-core", + "libipld-macro", + "log", + "multihash 0.18.1", + "thiserror", +] + +[[package]] +name = "libipld-cbor" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77d98c9d1747aa5eef1cf099cd648c3fd2d235249f5fed07522aaebc348e423b" +dependencies = [ + "byteorder", + "libipld-core", + "thiserror", +] + +[[package]] +name = "libipld-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5acd707e8d8b092e967b2af978ed84709eaded82b75effe6cb6f6cc797ef8158" +dependencies = [ + "anyhow", + "cid 0.10.1", + "core2", + "multibase", + "multihash 0.18.1", + "thiserror", +] + +[[package]] +name = "libipld-macro" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71171c54214f866ae6722f3027f81dff0931e600e5a61e6b1b6a49ca0b5ed4ae" +dependencies = [ + "libipld-core", +] + [[package]] name = "libnghttp2-sys" version = "0.1.10+1.61.0" @@ -1325,6 +1568,17 @@ dependencies = [ "data-encoding-macro", ] +[[package]] +name = "multihash" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfd8a792c1694c6da4f68db0a9d707c72bd260994da179e6030a5dcee00bb815" +dependencies = [ + "core2", + "multihash-derive", + "unsigned-varint 0.7.2", +] + [[package]] name = "multihash" version = "0.19.1" @@ -1336,6 +1590,20 @@ dependencies = [ "unsigned-varint 0.7.2", ] +[[package]] +name = "multihash-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6d4752e6230d8ef7adf7bd5d8c4b1f6561c1014c5ba9a37445ccefe18aa1db" +dependencies = [ + "proc-macro-crate", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", + "synstructure", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -1563,6 +1831,40 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-crate" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +dependencies = [ + "thiserror", + "toml 0.5.11", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version 
= "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -1792,6 +2094,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rs-car" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf69c4017006c0101362b5df74ee230331703e9938f970468dc1e429afe12998" +dependencies = [ + "blake2b_simd", + "futures", + "libipld", + "sha2", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1864,6 +2178,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.18" @@ -2014,6 +2334,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -2108,6 +2439,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -2142,6 +2479,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "unicode-xid", +] + [[package]] name = "tempfile" version = "3.10.1" @@ -2240,6 +2589,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -2253,6 +2616,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + [[package]] name = "toml" version = "0.8.15" @@ -2362,6 +2734,26 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + 
"data-encoding", + "http 1.1.0", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -2395,6 +2787,12 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +[[package]] +name = "unicode-xid" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" + [[package]] name = "unsigned-varint" version = "0.7.2" @@ -2424,6 +2822,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index ffacaf5e..6f9c9698 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ members = [ "atrium-crypto", "atrium-xrpc", "atrium-xrpc-client", + "atrium-xrpc-wss", + "atrium-xrpc-wss-client", "bsky-cli", "bsky-sdk", ] @@ -26,6 +28,8 @@ keywords = ["atproto", "bluesky"] atrium-api = { version = "0.24.4", path = "atrium-api" } atrium-xrpc = { version = "0.11.3", path = "atrium-xrpc" } atrium-xrpc-client = { version = "0.5.6", path = "atrium-xrpc-client" } +atrium-xrpc-wss = { version = "0.1.0", path = "atrium-xrpc-wss" } +atrium-xrpc-wss-client = { version = "0.1.0", path = "atrium-xrpc-wss-client" } bsky-sdk = { version = "0.1.9", path = "bsky-sdk" } # async in traits @@ -35,6 +39,10 @@ async-trait = "0.1.80" # DAG-CBOR codec ipld-core = { version = "0.4.1", default-features = false, features = ["std"] } serde_ipld_dagcbor = { version = "0.6.0", default-features = false, features = ["std"] } +cbor4ii = { version = "0.2.14", default-features = false } + +# CAR files +rs-car = "0.4.1" # Parsing and validation chrono = "0.4" @@ -55,8 +63,10 @@ rand = "0.8.5" # Networking futures = { version = "0.3.30", default-features = false, features = ["alloc"] } +async-stream = "0.3.5" http = "1.1.0" tokio = { version = "1.37", default-features = false } +tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } # HTTP client integrations isahc = "1.7.2" @@ -76,3 +86,6 @@ mockito = "1.4" # WebAssembly wasm-bindgen-test = "0.3.41" bumpalo = "~3.14.0" + +# Code generation +bon = "2.2.1" \ No newline at end of file diff --git a/atrium-xrpc-wss-client/.gitignore b/atrium-xrpc-wss-client/.gitignore new file mode 100644 index 00000000..4fffb2f8 --- /dev/null +++ b/atrium-xrpc-wss-client/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/atrium-xrpc-wss-client/CHANGELOG.md b/atrium-xrpc-wss-client/CHANGELOG.md new file mode 100644 index 00000000..df3cff36 --- /dev/null +++ b/atrium-xrpc-wss-client/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
\ No newline at end of file diff --git a/atrium-xrpc-wss-client/Cargo.toml b/atrium-xrpc-wss-client/Cargo.toml new file mode 100644 index 00000000..cdb6306b --- /dev/null +++ b/atrium-xrpc-wss-client/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "atrium-xrpc-wss-client" +version = "0.1.0" +authors = ["Elaina <17bestradiol@proton.me>"] +edition.workspace = true +rust-version.workspace = true +description = "XRPC Websocket Client library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-xrpc-wss-client" +readme = "README.md" +repository.workspace = true +license.workspace = true +keywords.workspace = true + +[dependencies] +atrium-xrpc.workspace = true +atrium-xrpc-wss.workspace = true +futures.workspace = true +ipld-core.workspace = true +async-stream.workspace = true +tokio-tungstenite.workspace = true +serde_ipld_dagcbor.workspace = true +rs-car.workspace = true +tokio.workspace = true +bon.workspace = true +serde_html_form.workspace = true +serde.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/atrium-xrpc-wss-client/README.md b/atrium-xrpc-wss-client/README.md new file mode 100644 index 00000000..5e919be9 --- /dev/null +++ b/atrium-xrpc-wss-client/README.md @@ -0,0 +1 @@ +# ATrium XRPC WSS Client \ No newline at end of file diff --git a/atrium-xrpc-wss-client/src/client.rs b/atrium-xrpc-wss-client/src/client.rs new file mode 100644 index 00000000..1db7c11f --- /dev/null +++ b/atrium-xrpc-wss-client/src/client.rs @@ -0,0 +1,87 @@ +//! This file provides a client for the `ATProto` XRPC over WSS protocol. +//! It implements the [`WssClient`] trait for the [`XrpcWssClient`] struct. + +use std::str::FromStr; + +use futures::Stream; +use tokio::net::TcpStream; + +use atrium_xrpc::{ + http::{Request, Uri}, + types::Header, +}; +use bon::Builder; +use serde::Serialize; +use tokio_tungstenite::{ + connect_async, + tungstenite::{self, handshake::client::generate_key}, + MaybeTlsStream, WebSocketStream, +}; + +use atrium_xrpc_wss::client::{WssClient, XrpcUri}; + +/// An enum of possible error kinds for this crate. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Invalid uri")] + InvalidUri, + #[error("Parsing parameters failed: {0}")] + ParsingParameters(#[from] serde_html_form::ser::Error), + #[error("Connection error: {0}")] + Connection(#[from] tungstenite::Error), +} + +#[derive(Builder)] +pub struct XrpcWssClient<'a, P: Serialize> { + xrpc_uri: XrpcUri<'a>, + params: Option
<P>
, +} + +type StreamKind = WebSocketStream>; +impl WssClient<::Item, Error> + for XrpcWssClient<'_, P> +{ + async fn connect(&self) -> Result::Item>, Error> { + let Self { xrpc_uri, params } = self; + let mut uri = xrpc_uri.to_uri(); + //// Query parameters + if let Some(p) = ¶ms { + uri.push('?'); + uri += &serde_html_form::to_string(p)?; + }; + //// + + //// Request + // Extracting the authority from the URI to set the Host header. + let uri = Uri::from_str(&uri).map_err(|_| Error::InvalidUri)?; + let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); + let host = authority + .find('@') + .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); + + // Building the request. + let mut request = Request::builder() + .uri(&uri) + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()); + + // Adding the ATProto headers. + if let Some(proxy) = self.atproto_proxy_header().await { + request = request.header(Header::AtprotoProxy, proxy); + } + if let Some(accept_labelers) = self.atproto_accept_labelers_header().await { + request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); + } + + // In our case, the only thing that could possibly fail is the URI. The headers are all `String`/`&str`. + let request = request.body(()).map_err(|_| Error::InvalidUri)?; + //// + + let (stream, _) = connect_async(request).await?; + Ok(stream) + } +} diff --git a/atrium-xrpc-wss-client/src/lib.rs b/atrium-xrpc-wss-client/src/lib.rs new file mode 100644 index 00000000..ed422338 --- /dev/null +++ b/atrium-xrpc-wss-client/src/lib.rs @@ -0,0 +1,6 @@ +mod client; +pub use client::{Error, XrpcWssClient}; + +pub mod subscriptions; + +pub use atrium_xrpc_wss; // Re-export the atrium_xrpc_wss crate \ No newline at end of file diff --git a/atrium-xrpc-wss-client/src/subscriptions/mod.rs b/atrium-xrpc-wss-client/src/subscriptions/mod.rs new file mode 100644 index 00000000..21b552a0 --- /dev/null +++ b/atrium-xrpc-wss-client/src/subscriptions/mod.rs @@ -0,0 +1 @@ +pub mod repositories; diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs new file mode 100644 index 00000000..c7727f0f --- /dev/null +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs @@ -0,0 +1,225 @@ +use std::{collections::BTreeMap, io::Cursor}; + +use futures::io::Cursor as FutCursor; +use ipld_core::cid::Cid; + +use super::type_defs::{self, Operation}; +use atrium_xrpc_wss::{ + atrium_api::{ + com::atproto::sync::subscribe_repos::{self, CommitData, InfoData, RepoOpData}, + record::KnownRecord, + types::Object, + }, + subscriptions::{ + handlers::repositories::{HandledData, Handler, ProcessedData}, + ConnectionHandler, ProcessedPayload, + } +}; + +/// Errors for this crate +#[derive(Debug, thiserror::Error)] +pub enum HandlingError { + #[error("CAR Decoding error: {0}")] + CarDecoding(#[from] rs_car::CarDecodeError), + #[error("IPLD Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), +} + +pub struct Firehose; +impl ConnectionHandler for Firehose { + type HandledData = HandledData; + type HandlingError = self::HandlingError; + + async fn handle_payload( + &self, + t: String, + payload: Vec, + ) -> Result>, Self::HandlingError> { + let res = match t.as_str() { + "#commit" => self + 
.process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Commit)), + "#identity" => self + .process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Identity)), + "#account" => self + .process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Account)), + "#handle" => self + .process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Handle)), + "#migrate" => self + .process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Migrate)), + "#tombstone" => self + .process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Tombstone)), + "#info" => self + .process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Info)), + _ => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + return Ok(None); + } + }; + + Ok(res) + } +} + +impl Handler for Firehose { + type ProcessedCommitData = type_defs::ProcessedCommitData; + async fn process_commit( + &self, + payload: subscribe_repos::Commit, + ) -> Result>, Self::HandlingError> { + let CommitData { + blobs, + blocks, + commit, + ops, + repo, + rev, + seq, + since, + time, + too_big, + .. + } = payload.data; + + // If it is too big, the blocks and ops are not sent, so we skip the processing. + let ops_opt = if too_big { + None + } else { + // We read all the blocks from the CAR file and store them in a map + // so that we can look up the data for each operation by its CID. + let mut cursor = FutCursor::new(blocks); + let mut map = rs_car::car_read_all(&mut cursor, true) + .await? + .0 + .into_iter() + .map(compat_cid) + .collect::>(); + + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + Some(process_ops(ops, &mut map)?) 
+ }; + + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedCommitData { + ops: ops_opt, + blobs, + commit, + repo, + rev, + since, + time, + }, + })) + } + + type ProcessedIdentityData = type_defs::ProcessedIdentityData; + async fn process_identity( + &self, + _payload: subscribe_repos::Identity, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedAccountData = type_defs::ProcessedAccountData; + async fn process_account( + &self, + _payload: subscribe_repos::Account, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedHandleData = type_defs::ProcessedHandleData; + async fn process_handle( + &self, + _payload: subscribe_repos::Handle, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedMigrateData = type_defs::ProcessedMigrateData; + async fn process_migrate( + &self, + _payload: subscribe_repos::Migrate, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedTombstoneData = type_defs::ProcessedTombstoneData; + async fn process_tombstone( + &self, + _payload: subscribe_repos::Tombstone, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedInfoData = InfoData; + async fn process_info( + &self, + payload: subscribe_repos::Info, + ) -> Result>, Self::HandlingError> { + Ok(Some(ProcessedPayload { + seq: None, + data: payload.data, + })) + } +} + +// Transmute is here because the version of the `rs_car` crate for `cid` is 0.10.1 whereas +// the `ilpd_core` crate is 0.11.1. Should work regardless, given that the Cid type's +// memory layout was not changed between the two versions. Temporary fix. +// TODO: Find a better way to fix the version compatibility issue. +fn compat_cid((cid, item): (rs_car::Cid, Vec)) -> (ipld_core::cid::Cid, Vec) { + (unsafe { std::mem::transmute::<_, Cid>(cid) }, item) +} + +fn process_ops( + ops: Vec>, + map: &mut BTreeMap>, +) -> Result, serde_ipld_dagcbor::DecodeError> { + let mut processed_ops = Vec::with_capacity(ops.len()); + for op in ops { + processed_ops.push(process_op(map, op)?); + } + Ok(processed_ops) +} + +/// Processes a single operation. +fn process_op( + map: &mut BTreeMap>, + op: Object, +) -> Result> { + let RepoOpData { action, path, cid } = op.data; + + // Finds in the map the `Record` with the operation's CID and deserializes it. + // If the item is not found, returns `None`. + let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) { + Some(item) => Some(serde_ipld_dagcbor::from_reader::( + Cursor::new(item), + )?), + None => None, + }; + + Ok(Operation { + action, + path, + record, + }) +} diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs new file mode 100644 index 00000000..cdc4110a --- /dev/null +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs @@ -0,0 +1,107 @@ +pub mod firehose; +pub mod type_defs; + +use std::marker::PhantomData; + +use async_stream::stream; +use bon::bon; +use futures::{Stream, StreamExt}; +use tokio_tungstenite::tungstenite::Message; + +use atrium_xrpc_wss::{atrium_api::com::atproto::sync::subscribe_repos, subscriptions::{ + frames::{self, Frame}, + ConnectionHandler, ProcessedPayload, Subscription, SubscriptionError, +}}; + +/// A struct that represents the repositories subscription, used in `com.atproto.sync.subscribeRepos`. 
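+///
+/// A minimal usage sketch, mirroring the firehose example later in this patch; it
+/// assumes an already-established WebSocket `connection` and the `Firehose` handler
+/// from this crate:
+///
+/// ```ignore
+/// use futures::StreamExt;
+///
+/// let mut subscription = Repositories::builder()
+///     .connection(connection)
+///     .handler(Firehose)
+///     .build();
+/// while let Some(payload) = subscription.next().await {
+///     // Each item is a `Result<ProcessedPayload<..>, SubscriptionError<..>>`.
+/// }
+/// ```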
+pub struct Repositories { + /// This is only here to constrain the `ConnectionPayload` used in [`Subscription`], or else we get a compile error. + _payload_kind: PhantomData, +} +#[bon] +impl Repositories +where + Self: Subscription, +{ + #[builder] + /// Defines the builder for any generic `Repositories` struct that implements [`Subscription`]. + pub fn new( + connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream, SubscriptionError>> + { + Self::handle_connection(connection, handler) + } +} + +type WssResult = tokio_tungstenite::tungstenite::Result; +impl Subscription for Repositories { + fn handle_connection( + mut connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream, SubscriptionError>> + { + // Builds a new async stream that will deserialize the packets sent through the + // TCP tunnel and then yield the results processed by the handler back to the caller. + let stream = stream! { + loop { + match connection.next().await { + None => break, // Server dropped connection + Some(Err(e)) => { // WebSocket error + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); + break; + } + Some(Ok(Message::Binary(data))) => { + match Frame::try_from(data) { + Ok(Frame::Message { t, data: payload }) => { + match handler.handle_payload(t, payload).await { + Ok(Some(res)) => yield Ok(res), // Payload was successfully handled. + Ok(None) => {}, // Payload was ignored by Handler. + Err(e) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid payload. Error: {e:?}"))); + break; + }, + } + }, + Ok(Frame::Error { data }) => { + yield match serde_ipld_dagcbor::from_reader::(data.as_slice()) { + Ok(e) => Err(SubscriptionError::Other(e)), + Err(e) => Err(SubscriptionError::Unknown(format!("Failed to decode error frame: {e:?}"))), + }; + break; + }, + Err(frames::Error::EmptyPayload(ipld)) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard frames::errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received empty payload for header: {ipld:?}"))); + break; + }, + Err(frames::Error::IpldDecoding(e)) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); + break; + }, + Err(frames::Error::UnknownFrameType(_)) => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + }, + } + } + _ => {}, // Ignore other message types. 
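+                    // (Other WebSocket message kinds, e.g. Text, Ping, Pong and Close,
+                    // carry no subscription payload and are simply skipped here.)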
+ } + } + }; + + Box::pin(stream) + } +} diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs new file mode 100644 index 00000000..193fbd4c --- /dev/null +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs @@ -0,0 +1,54 @@ +//! This file defines the types used in the Firehose handler. + +use atrium_xrpc_wss::atrium_api::{ + record::KnownRecord, + types::{ + string::{Datetime, Did}, + CidLink, + }, +}; + +// region: Commit +#[derive(Debug)] +pub struct ProcessedCommitData { + pub repo: Did, + pub commit: CidLink, + // `ops` can be `None` if the commit is marked as `too_big`. + pub ops: Option>, + pub blobs: Vec, + pub rev: String, + pub since: Option, + pub time: Datetime, +} +#[derive(Debug)] +pub struct Operation { + pub action: String, + pub path: String, + pub record: Option, +} +// endregion: Commit + +// region: Identity +#[derive(Debug)] +pub struct ProcessedIdentityData {} +// endregion: Identity + +// region: Account +#[derive(Debug)] +pub struct ProcessedAccountData {} +// endregion: Account + +// region: Handle +#[derive(Debug)] +pub struct ProcessedHandleData {} +// endregion: Handle + +// region: Migrate +#[derive(Debug)] +pub struct ProcessedMigrateData {} +// endregion: Migrate + +// region: Tombstone +#[derive(Debug)] +pub struct ProcessedTombstoneData {} +// endregion: Tombstone diff --git a/atrium-xrpc-wss/.gitignore b/atrium-xrpc-wss/.gitignore new file mode 100644 index 00000000..4fffb2f8 --- /dev/null +++ b/atrium-xrpc-wss/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/atrium-xrpc-wss/CHANGELOG.md b/atrium-xrpc-wss/CHANGELOG.md new file mode 100644 index 00000000..df3cff36 --- /dev/null +++ b/atrium-xrpc-wss/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). \ No newline at end of file diff --git a/atrium-xrpc-wss/Cargo.toml b/atrium-xrpc-wss/Cargo.toml new file mode 100644 index 00000000..d3d760fb --- /dev/null +++ b/atrium-xrpc-wss/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "atrium-xrpc-wss" +version = "0.1.0" +authors = ["Elaina <17bestradiol@proton.me>"] +edition.workspace = true +rust-version.workspace = true +description = "XRPC Websocket library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-xrpc-wss" +readme = "README.md" +repository.workspace = true +license.workspace = true +keywords.workspace = true + +[dependencies] +atrium-api.workspace = true +futures.workspace = true +cbor4ii.workspace = true +ipld-core.workspace = true +serde.workspace = true +serde_ipld_dagcbor.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/atrium-xrpc-wss/README.md b/atrium-xrpc-wss/README.md new file mode 100644 index 00000000..0ada3f05 --- /dev/null +++ b/atrium-xrpc-wss/README.md @@ -0,0 +1 @@ +# ATrium XRPC WSS \ No newline at end of file diff --git a/atrium-xrpc-wss/src/client/mod.rs b/atrium-xrpc-wss/src/client/mod.rs new file mode 100644 index 00000000..09ce804a --- /dev/null +++ b/atrium-xrpc-wss/src/client/mod.rs @@ -0,0 +1,27 @@ +mod xprc_uri; + +use std::future::Future; + +use futures::Stream; +pub use xprc_uri::XrpcUri; + +/// An abstract WSS client. +pub trait WssClient { + /// Send an XRPC request. 
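+    /// (For a WSS client this amounts to performing the WebSocket handshake and
+    /// returning the resulting connection stream.)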
+ /// + /// # Returns + /// [`Result`] + fn connect( + &self, + ) -> impl Future, ConnectionError>> + Send; + + /// Get the `atproto-proxy` header. + fn atproto_proxy_header(&self) -> impl Future> + Send { + async { None } + } + + /// Get the `atproto-accept-labelers` header. + fn atproto_accept_labelers_header(&self) -> impl Future>> + Send { + async { None } + } +} diff --git a/atrium-xrpc-wss/src/client/xprc_uri.rs b/atrium-xrpc-wss/src/client/xprc_uri.rs new file mode 100644 index 00000000..519dc850 --- /dev/null +++ b/atrium-xrpc-wss/src/client/xprc_uri.rs @@ -0,0 +1,16 @@ +/// The URI for the XRPC `WebSocket` connection. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct XrpcUri<'a> { + base_uri: &'a str, + nsid: &'a str, +} +impl<'a> XrpcUri<'a> { + pub const fn new(base_uri: &'a str, nsid: &'a str) -> Self { + Self { base_uri, nsid } + } + + pub fn to_uri(&self) -> String { + let XrpcUri { base_uri, nsid } = self; + format!("wss://{base_uri}/xrpc/{nsid}") + } +} diff --git a/atrium-xrpc-wss/src/lib.rs b/atrium-xrpc-wss/src/lib.rs new file mode 100644 index 00000000..c0b81c79 --- /dev/null +++ b/atrium-xrpc-wss/src/lib.rs @@ -0,0 +1,4 @@ +pub mod client; +pub mod subscriptions; + +pub use atrium_api; // Re-export the atrium_api crate \ No newline at end of file diff --git a/atrium-xrpc-wss/src/subscriptions/frames/mod.rs b/atrium-xrpc-wss/src/subscriptions/frames/mod.rs new file mode 100644 index 00000000..0ec13090 --- /dev/null +++ b/atrium-xrpc-wss/src/subscriptions/frames/mod.rs @@ -0,0 +1,86 @@ +//! This file defines the [`FrameHeader`] and [`Frame`] types, which are used to parse the payloads sent by the subscription through the event stream. +//! You can read more about the specs for these types in the [`ATProto documentation`](https://atproto.com/specs/event-stream) + +#[cfg(test)] +mod tests; + +use cbor4ii::core::utils::IoReader; +use ipld_core::ipld::Ipld; +use serde::Deserialize; +use serde_ipld_dagcbor::de::Deserializer; +use std::io::Cursor; + +/// An error type for this crate. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Unknown frame type. Header: {0:?}")] + UnknownFrameType(Ipld), + #[error("Payload was empty. Header: {0:?}")] + EmptyPayload(Ipld), + #[error("Ipld Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), +} + +/// Represents the header of a frame. It's the first [`Ipld`] object in a Binary payload sent by a subscription. +#[derive(Debug, Clone, PartialEq, Eq)] +enum FrameHeader { + Message { t: String }, + Error, +} + +impl TryFrom for FrameHeader { + type Error = self::Error; + + fn try_from(header: Ipld) -> Result>::Error> { + if let Ipld::Map(ref map) = header { + if let Some(Ipld::Integer(i)) = map.get("op") { + match i { + 1 => { + if let Some(Ipld::String(s)) = map.get("t") { + return Ok(Self::Message { t: s.to_owned() }); + } + } + -1 => return Ok(Self::Error), + _ => {} + } + } + } + Err(Error::UnknownFrameType(header)) + } +} + +/// Represents a frame sent by a subscription. It's the second [`Ipld`] object in a Binary payload sent by a subscription. 
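+/// The `data` field holds the raw DAG-CBOR bytes that follow the header; decoding
+/// the payload body is left to the caller (e.g. a `ConnectionHandler`).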
+#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Frame { + Message { + t: String, + data: Vec, + }, + Error { + data: Vec, + }, +} + +impl TryFrom> for Frame { + type Error = self::Error; + + fn try_from(value: Vec) -> Result>>::Error> { + let mut cursor = Cursor::new(value); + let mut deserializer = Deserializer::from_reader(IoReader::new(&mut cursor)); + let header = Deserialize::deserialize(&mut deserializer)?; + + // Error means the stream did not end (trailing data), which implies a second IPLD (in this case, the payload). + // If the stream ended, the payload is empty, in which case we error. + let data = if deserializer.end().is_err() { + let pos = cursor.position() as usize; + cursor.get_mut().drain(pos..).collect() + } else { + return Err(Error::EmptyPayload(header)); + }; + + match FrameHeader::try_from(header)? { + FrameHeader::Message { t } => Ok(Self::Message { t, data }), + FrameHeader::Error => Ok(Self::Error { data }), + } + } +} diff --git a/atrium-xrpc-wss/src/subscriptions/frames/tests.rs b/atrium-xrpc-wss/src/subscriptions/frames/tests.rs new file mode 100644 index 00000000..3f8a654d --- /dev/null +++ b/atrium-xrpc-wss/src/subscriptions/frames/tests.rs @@ -0,0 +1,61 @@ +use super::*; + +fn serialized_data(s: &str) -> Vec { + assert!(s.len() % 2 == 0); + let b2u = |b: u8| match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => b - b'a' + 10, + _ => unreachable!(), + }; + s.as_bytes() + .chunks(2) + .map(|b| (b2u(b[0]) << 4) + b2u(b[1])) + .collect() +} + +#[test] +fn deserialize_message_frame_header() { + // {"op": 1, "t": "#commit"} + let data = serialized_data("a2626f700161746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect("failed to deserialize"), + FrameHeader::Message { + t: String::from("#commit") + } + ); +} + +#[test] +fn deserialize_error_frame_header() { + // {"op": -1} + let data = serialized_data("a1626f7020"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error); +} + +#[test] +fn deserialize_invalid_frame_header() { + { + // {"op": 2, "t": "#commit"} + let data = serialized_data("a2626f700261746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. Header: {\"op\": 2, \"t\": \"#commit\"}" + ); + } + { + // {"op": -2} + let data = serialized_data("a1626f7021"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. 
Header: {\"op\": -2}" + ); + } +} diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs b/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs new file mode 100644 index 00000000..3250664f --- /dev/null +++ b/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs @@ -0,0 +1,3 @@ +use super::{ConnectionHandler, ProcessedPayload}; + +pub mod repositories; \ No newline at end of file diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs b/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs new file mode 100644 index 00000000..0e98cdfa --- /dev/null +++ b/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs @@ -0,0 +1,125 @@ +#![allow(unused_variables)] + +use std::future::Future; + +use atrium_api::com::atproto::sync::subscribe_repos; + +use super::{ConnectionHandler, ProcessedPayload}; + +/// This type should be used to define [`ConnectionHandler::HandledData`](ConnectionHandler::HandledData) +/// for the `com.atproto.sync.subscribeRepos` subscription type. +pub type HandledData = ProcessedData< + ::ProcessedCommitData, + ::ProcessedIdentityData, + ::ProcessedAccountData, + ::ProcessedHandleData, + ::ProcessedMigrateData, + ::ProcessedTombstoneData, + ::ProcessedInfoData, +>; + +/// Wrapper around all the possible types of processed data. +#[derive(Debug)] +pub enum ProcessedData { + Commit(C), + Identity(I0), + Account(A), + Handle(H), + Migrate(M), + Tombstone(T), + Info(I1), +} + +/// A trait that defines a [`ConnectionHandler`] specific to the +/// `com.atproto.sync.subscribeRepos` subscription type. +/// +/// Any struct that fully and correctly implements this trait will be able to +/// handle all the different payload types that the subscription can send. +/// Since the final desired result data type might change for each case, the +/// trait is generic, and the implementor must define the data type for each +/// payload they pretend to use. The same goes for the implementations of +/// each processing method, as the algorithm may vary. +pub trait Handler: ConnectionHandler { + type ProcessedCommitData; + /// Processes a payload of type `#commit`. + fn process_commit( + &self, + payload: subscribe_repos::Commit, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedIdentityData; + /// Processes a payload of type `#identity`. + fn process_identity( + &self, + payload: subscribe_repos::Identity, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedAccountData; + /// Processes a payload of type `#account`. + fn process_account( + &self, + payload: subscribe_repos::Account, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedHandleData; + /// Processes a payload of type `#handle`. + fn process_handle( + &self, + payload: subscribe_repos::Handle, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedMigrateData; + /// Processes a payload of type `#migrate`. 
+ fn process_migrate( + &self, + payload: subscribe_repos::Migrate, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedTombstoneData; + /// Processes a payload of type `#tombstone`. + fn process_tombstone( + &self, + payload: subscribe_repos::Tombstone, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } + + type ProcessedInfoData; + /// Processes a payload of type `#info`. + fn process_info( + &self, + payload: subscribe_repos::Info, + ) -> impl Future>, Self::HandlingError>> + { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } +} diff --git a/atrium-xrpc-wss/src/subscriptions/mod.rs b/atrium-xrpc-wss/src/subscriptions/mod.rs new file mode 100644 index 00000000..a3a3db4b --- /dev/null +++ b/atrium-xrpc-wss/src/subscriptions/mod.rs @@ -0,0 +1,81 @@ +pub mod frames; +pub mod handlers; + +use std::{fmt::Debug, future::Future}; + +use futures::Stream; + +/// A trait that defines the connection handler. +pub trait ConnectionHandler { + /// The [`Self::HandledData`](ConnectionHandler::HandledData) type should be used to define the returned processed data type. + type HandledData; + /// The [`Self::HandlingError`](ConnectionHandler::HandlingError) type should be used to define the processing error type. + type HandlingError: 'static + Send + Sync + Debug; + + /// Handles binary data coming from the connection. This function will deserialize the payload body and call the appropriate + /// handler for each payload type. + /// + /// # Returns + /// [`Result>`] like: + /// - `Ok(Some(processedPayload))` where `processedPayload` is [`ProcessedPayload`](ProcessedPayload) + /// if the payload was successfully processed. + /// - `Ok(None)` if the payload was ignored. + /// - `Err(e)` where `e` is [`ConnectionHandler::HandlingError`] if an error occurred while processing the payload. + fn handle_payload( + &self, + t: String, + payload: Vec, + ) -> impl Future>, Self::HandlingError>>; +} + +/// A trait that defines a subscription. +/// It should be implemented by any struct that wants to handle a connection. +/// The `ConnectionPayload` type parameter is the type of the payload that will be received through the connection stream. +/// The `Error` type parameter is the type of the error that the specific subscription can return, following the lexicon. +pub trait Subscription { + /// The `handle_connection` method should be implemented to handle the connection. + /// + /// # Returns + /// A stream of processed payloads. + fn handle_connection( + connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream, SubscriptionError>>; +} + +/// This struct represents a processed payload. +/// It contains the sequence number (cursor) and the final processed data. +pub struct ProcessedPayload { + pub seq: Option, // Might be absent, like in the case of #info. + pub data: Kind, +} + +/// Helper function to convert between payload kinds. +impl ProcessedPayload { + pub fn map NewKind>(self, f: F) -> ProcessedPayload { + ProcessedPayload { + seq: self.seq, + data: f(self.data), + } + } +} + +/// An error type that represents a subscription error. +/// +/// `Abort` is a hard error, and the subscription should cancel. 
+/// This follows the [`ATProto Specs`](https://atproto.com/specs/event-stream). +/// +/// `Unknown` is an error that is not recognized by the subscription. +/// This can be used to handle unexpected errors. +/// +/// `Other` is an error specific to the subscription type. +/// This can be used to handle different kinds of errors, following the lexicon. +#[derive(Debug, thiserror::Error)] +pub enum SubscriptionError { + #[error("Critical Subscription Error: {0}")] + Abort(String), + #[error("Unknown Subscription Error: {0}")] + Unknown(String), + #[error(transparent)] + Other(T), +} diff --git a/examples/firehose/Cargo.toml b/examples/firehose/Cargo.toml index 70bab61d..4c2cf0d8 100644 --- a/examples/firehose/Cargo.toml +++ b/examples/firehose/Cargo.toml @@ -6,13 +6,8 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.80" -atrium-api = { version = "0.18.1", features = ["dag-cbor"] } -chrono = "0.4.34" +anyhow = "1.0.86" +atrium-xrpc-wss-client = { path = "../../atrium-xrpc-wss-client" } futures = "0.3.30" -ipld-core = { version = "0.4.0", default-features = false, features = ["std"] } -rs-car = "0.4.1" -serde_ipld_dagcbor = { version = "0.6.0", default-features = false, features = ["std"] } -tokio = { version = "1.36.0", features = ["full"] } tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } -trait-variant = "0.1.1" +tokio = { version = "1.36.0", features = ["full"] } \ No newline at end of file diff --git a/examples/firehose/src/lib.rs b/examples/firehose/src/lib.rs deleted file mode 100644 index b4e04262..00000000 --- a/examples/firehose/src/lib.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod stream; -pub mod subscription; diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs index e70a237e..8a5c8257 100644 --- a/examples/firehose/src/main.rs +++ b/examples/firehose/src/main.rs @@ -1,85 +1,146 @@ -use anyhow::{anyhow, Result}; -use atrium_api::app::bsky::feed::post::Record; -use atrium_api::com::atproto::sync::subscribe_repos::{Commit, NSID}; -use atrium_api::types::{CidLink, Collection}; -use chrono::Local; -use firehose::stream::frames::Frame; -use firehose::subscription::{CommitHandler, Subscription}; +use anyhow::bail; +use atrium_xrpc_wss_client::{ + atrium_xrpc_wss::{ + atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData}, + client::{WssClient, XrpcUri}, + subscriptions::{ + repositories::ProcessedData, + ProcessedPayload, SubscriptionError, + }, + }, + subscriptions::repositories::{ + firehose::Firehose, + type_defs::{Operation, ProcessedCommitData}, Repositories, + }, + Error, XrpcWssClient, +}; use futures::StreamExt; -use tokio::net::TcpStream; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; +use tokio_tungstenite::tungstenite; -struct RepoSubscription { - stream: WebSocketStream>, -} +/// This example demonstrates how to connect to the ATProto Firehose. +#[tokio::main] +async fn main() { + // Define the XrpcUri for the subscription. 
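+    // `XrpcUri::to_uri()` expands this to `wss://<base_uri>/xrpc/<nsid>`, here
+    // `wss://bsky.network/xrpc/com.atproto.sync.subscribeRepos`.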
+ let xrpc_uri = XrpcUri::new("bsky.network", subscribe_repos::NSID); -impl RepoSubscription { - async fn new(bgs: &str) -> Result> { - let (stream, _) = connect_async(format!("wss://{bgs}/xrpc/{NSID}")).await?; - Ok(RepoSubscription { stream }) - } - async fn run(&mut self, handler: impl CommitHandler) -> Result<(), Box> { - while let Some(result) = self.next().await { - if let Ok(Frame::Message(Some(t), message)) = result { - if t.as_str() == "#commit" { - let commit = serde_ipld_dagcbor::from_reader(message.body.as_slice())?; - if let Err(err) = handler.handle_commit(&commit).await { - eprintln!("FAILED: {err:?}"); - } - } - } - } - Ok(()) - } + // Caching the last cursor is important. + // The API has a backfilling mechanism that allows you to resume from where you stopped. + let mut last_cursor = None; + drop(connect(&mut last_cursor, &xrpc_uri).await); } -impl Subscription for RepoSubscription { - async fn next(&mut self) -> Option>::Error>> { - if let Some(Ok(Message::Binary(data))) = self.stream.next().await { - Some(Frame::try_from(data.as_slice())) - } else { - None - } +/// Connects to `ATProto` to receive real-time data. +async fn connect( + last_cursor: &mut Option, + xrpc_uri: &XrpcUri<'_>, +) -> Result<(), anyhow::Error> { + // Define the query parameters. In this case, just the cursor. + let params = subscribe_repos::ParametersData { + cursor: *last_cursor, + }; + + // Build a new XRPC WSS Client. + let client = XrpcWssClient::builder() + .xrpc_uri(xrpc_uri.clone()) + .params(params) + .build(); + + // And then we connect to the API. + let connection = match client.connect().await { + Ok(connection) => connection, + Err(Error::Connection(tungstenite::Error::Http(response))) => { + // According to the API documentation, the following status codes are expected and should be treated accordingly: + // 405 Method Not Allowed: Returned to client for non-GET HTTP requests to a stream endpoint. + // 426 Upgrade Required: Returned to client if Upgrade header is not included in a request to a stream endpoint. + // 429 Too Many Requests: Frequently used for rate-limiting. Client may try again after a delay. Support for the Retry-After header is encouraged. + // 500 Internal Server Error: Client may try again after a delay + // 501 Not Implemented: Service does not implement WebSockets or streams, at least for this endpoint. Client should not try again. + // 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout: Client may try again after a delay. + // https://atproto.com/specs/event-stream + bail!("Status Code was: {response:?}") } -} + Err(e) => bail!(e), + }; -struct Firehose; + // Builds a new subscription from the connection, using handler provided + // by atrium-xrpc-wss-client, the `Firehose`. 
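+    // (Any handler implementing the `atrium-xrpc-wss` repositories `Handler` trait
+    // could be used here in place of `Firehose`.)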
+ let mut subscription = Repositories::builder() + .connection(connection) + .handler(Firehose) + .build(); -impl CommitHandler for Firehose { - async fn handle_commit(&self, commit: &Commit) -> Result<()> { - for op in &commit.ops { - let collection = op.path.split('/').next().expect("op.path is empty"); - if op.action != "create" || collection != atrium_api::app::bsky::feed::Post::NSID { - continue; - } - let (items, _) = rs_car::car_read_all(&mut commit.blocks.as_slice(), true).await?; - if let Some((_, item)) = items.iter().find(|(cid, _)| Some(CidLink(*cid)) == op.cid) { - let record = serde_ipld_dagcbor::from_reader::(&mut item.as_slice())?; - println!( - "{} - {}", - record.created_at.as_ref().with_timezone(&Local), - commit.repo.as_str() - ); - for line in record.text.split('\n') { - println!(" {line}"); - } - } else { - return Err(anyhow!( - "FAILED: could not find item with operation cid {:?} out of {} items", - op.cid, - items.len() - )); - } + // Receive payloads by calling `StreamExt::next()`. + while let Some(payload) = subscription.next().await { + let data = match payload { + Ok(ProcessedPayload { seq, data }) => { + if let Some(seq) = seq { + *last_cursor = Some(seq); } - Ok(()) - } + data + } + Err(SubscriptionError::Abort(reason)) => { + // This could mean multiple things, all of which are critical errors that require + // immediate termination of connection. + eprintln!("Aborted: {reason}"); + *last_cursor = None; + break; + } + Err(e) => { + // Errors such as `FutureCursor` and `ConsumerTooSlow` can be dealt with here. + eprintln!("{e:?}"); + *last_cursor = None; + break; + } + }; + + match data { + ProcessedData::Commit(data) => beauty_print_commit(data), + ProcessedData::Info(InfoData { message, name }) => { + println!("Received info. Message: {message:?}; Name: {name}."); + } + _ => { /* Ignored */ } + }; + } + + Ok(()) } -#[tokio::main] -async fn main() -> Result<(), Box> { - RepoSubscription::new("bsky.network") - .await? - .run(Firehose) - .await +fn beauty_print_commit(data: ProcessedCommitData) { + let ProcessedCommitData { + repo, commit, ops, .. + } = data; + if let Some(ops) = ops { + for r in ops { + let Operation { + action, + path, + record, + } = r; + let print = format!( + "\n\n\n################################# {} ##################################\n\ + - Repository (User DID): {}\n\ + - Commit CID: {}\n\ + - Path: {path}\n\ + - Flagged as \"too big\"? ", + action.to_uppercase(), + repo.as_str(), + commit.0, + ); + // Record is only `None` when the commit was flagged as "too big". 
+ if let Some(record) = record { + println!( + "{}No\n\ + //-------------------------------- Record Info -------------------------------//\n\n\ + {:?}", + print, record + ); + } else { + println!( + "{}Yes\n\ + //---------------------------------------------------------------------------//\n\n", + print + ); + } + } + } } diff --git a/examples/firehose/src/stream.rs b/examples/firehose/src/stream.rs deleted file mode 100644 index b63fcef5..00000000 --- a/examples/firehose/src/stream.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod frames; diff --git a/examples/firehose/src/stream/frames.rs b/examples/firehose/src/stream/frames.rs deleted file mode 100644 index 3edd1c1e..00000000 --- a/examples/firehose/src/stream/frames.rs +++ /dev/null @@ -1,158 +0,0 @@ -use ipld_core::ipld::Ipld; -use std::io::Cursor; - -// original definition: -//``` -// export enum FrameType { -// Message = 1, -// Error = -1, -// } -// export const messageFrameHeader = z.object({ -// op: z.literal(FrameType.Message), // Frame op -// t: z.string().optional(), // Message body type discriminator -// }) -// export type MessageFrameHeader = z.infer -// export const errorFrameHeader = z.object({ -// op: z.literal(FrameType.Error), -// }) -// export type ErrorFrameHeader = z.infer -// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -enum FrameHeader { - Message(Option), - Error, -} - -impl TryFrom for FrameHeader { - type Error = anyhow::Error; - - fn try_from(value: Ipld) -> Result>::Error> { - if let Ipld::Map(map) = value { - if let Some(Ipld::Integer(i)) = map.get("op") { - match i { - 1 => { - let t = if let Some(Ipld::String(s)) = map.get("t") { - Some(s.clone()) - } else { - None - }; - return Ok(FrameHeader::Message(t)); - } - -1 => return Ok(FrameHeader::Error), - _ => {} - } - } - } - Err(anyhow::anyhow!("invalid frame type")) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Frame { - Message(Option, MessageFrame), - Error(ErrorFrame), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct MessageFrame { - pub body: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ErrorFrame { - // TODO - // body: Value, -} - -impl TryFrom<&[u8]> for Frame { - type Error = anyhow::Error; - - fn try_from(value: &[u8]) -> Result>::Error> { - let mut cursor = Cursor::new(value); - let (left, right) = match serde_ipld_dagcbor::from_reader::(&mut cursor) { - Err(serde_ipld_dagcbor::DecodeError::TrailingData) => { - value.split_at(cursor.position() as usize) - } - _ => { - // TODO - return Err(anyhow::anyhow!("invalid frame type")); - } - }; - let header = FrameHeader::try_from(serde_ipld_dagcbor::from_slice::(left)?)?; - if let FrameHeader::Message(t) = &header { - Ok(Frame::Message( - t.clone(), - MessageFrame { - body: right.to_vec(), - }, - )) - } else { - Ok(Frame::Error(ErrorFrame {})) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn serialized_data(s: &str) -> Vec { - assert!(s.len() % 2 == 0); - let b2u = |b: u8| match b { - b'0'..=b'9' => b - b'0', - b'a'..=b'f' => b - b'a' + 10, - _ => unreachable!(), - }; - s.as_bytes() - .chunks(2) - .map(|b| (b2u(b[0]) << 4) + b2u(b[1])) - .collect() - } - - #[test] - fn deserialize_message_frame_header() { - // {"op": 1, "t": "#commit"} - let data = serialized_data("a2626f700161746723636f6d6d6974"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect("failed to deserialize"), - FrameHeader::Message(Some(String::from("#commit"))) - ); - } - - #[test] - fn 
deserialize_error_frame_header() { - // {"op": -1} - let data = serialized_data("a1626f7020"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error); - } - - #[test] - fn deserialize_invalid_frame_header() { - { - // {"op": 2, "t": "#commit"} - let data = serialized_data("a2626f700261746723636f6d6d6974"); - let ipld = - serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect_err("must be failed").to_string(), - "invalid frame type" - ); - } - { - // {"op": -2} - let data = serialized_data("a1626f7021"); - let ipld = - serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect_err("must be failed").to_string(), - "invalid frame type" - ); - } - } -} diff --git a/examples/firehose/src/subscription.rs b/examples/firehose/src/subscription.rs deleted file mode 100644 index 90393105..00000000 --- a/examples/firehose/src/subscription.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::stream::frames::Frame; -use anyhow::Result; -use atrium_api::com::atproto::sync::subscribe_repos::Commit; -use std::future::Future; - -#[trait_variant::make(HttpService: Send)] -pub trait Subscription { - async fn next(&mut self) -> Option>::Error>>; -} - -pub trait CommitHandler { - fn handle_commit(&self, commit: &Commit) -> impl Future>; -} From 26a01d71664a59941d3375b7195a8998c19e0979 Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Thu, 12 Sep 2024 04:13:22 -0300 Subject: [PATCH 2/7] Formatting --- README.md | 8 + atrium-xrpc-wss-client/src/client.rs | 110 ++--- atrium-xrpc-wss-client/src/lib.rs | 4 +- .../subscriptions/repositories/firehose.rs | 378 +++++++++--------- .../src/subscriptions/repositories/mod.rs | 155 +++---- .../subscriptions/repositories/type_defs.rs | 32 +- atrium-xrpc-wss/src/client/mod.rs | 32 +- atrium-xrpc-wss/src/client/xprc_uri.rs | 18 +- atrium-xrpc-wss/src/lib.rs | 2 +- .../src/subscriptions/frames/mod.rs | 87 ++-- .../src/subscriptions/frames/tests.rs | 90 ++--- .../src/subscriptions/handlers/mod.rs | 2 +- .../subscriptions/handlers/repositories.rs | 184 ++++----- atrium-xrpc-wss/src/subscriptions/mod.rs | 78 ++-- examples/firehose/src/main.rs | 224 +++++------ 15 files changed, 708 insertions(+), 696 deletions(-) diff --git a/README.md b/README.md index 0822a1f0..be5c4574 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,14 @@ Definitions for XRPC request/response, and their associated errors. A library provides clients that implement the `XrpcClient` defined in [atrium-xrpc](./atrium-xrpc/) +### [`atrium-xrpc-wss`](./atrium-xrpc-wss/) + +Definitions for traits, types and utilities for dealing with WebSocket XRPC subscriptions. 
(WIP) + +### [`atrium-xrpc-wss-client`](./atrium-xrpc-wss-client/) + +A library that provides default implementations of the `XrpcWssClient`, `Handlers` and `Subscription` defined in [atrium-xrpc-wss](./atrium-xrpc-wss/) for interacting with the variety of subscriptions in ATProto (WIP) + ### [`bsky-sdk`](./bsky-sdk/) [![](https://img.shields.io/crates/v/bsky-sdk)](https://crates.io/crates/bsky-sdk) diff --git a/atrium-xrpc-wss-client/src/client.rs b/atrium-xrpc-wss-client/src/client.rs index 1db7c11f..a6deb84e 100644 --- a/atrium-xrpc-wss-client/src/client.rs +++ b/atrium-xrpc-wss-client/src/client.rs @@ -7,81 +7,81 @@ use futures::Stream; use tokio::net::TcpStream; use atrium_xrpc::{ - http::{Request, Uri}, - types::Header, + http::{Request, Uri}, + types::Header, }; use bon::Builder; use serde::Serialize; use tokio_tungstenite::{ - connect_async, - tungstenite::{self, handshake::client::generate_key}, - MaybeTlsStream, WebSocketStream, + connect_async, + tungstenite::{self, handshake::client::generate_key}, + MaybeTlsStream, WebSocketStream, }; -use atrium_xrpc_wss::client::{WssClient, XrpcUri}; +use atrium_xrpc_wss::client::{XrpcUri, XrpcWssClient}; /// An enum of possible error kinds for this crate. #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("Invalid uri")] - InvalidUri, - #[error("Parsing parameters failed: {0}")] - ParsingParameters(#[from] serde_html_form::ser::Error), - #[error("Connection error: {0}")] - Connection(#[from] tungstenite::Error), + #[error("Invalid uri")] + InvalidUri, + #[error("Parsing parameters failed: {0}")] + ParsingParameters(#[from] serde_html_form::ser::Error), + #[error("Connection error: {0}")] + Connection(#[from] tungstenite::Error), } #[derive(Builder)] -pub struct XrpcWssClient<'a, P: Serialize> { - xrpc_uri: XrpcUri<'a>, - params: Option
<P>
, +pub struct DefaultClient<'a, P: Serialize> { + xrpc_uri: XrpcUri<'a>, + params: Option
<P>
, } type StreamKind = WebSocketStream>; -impl WssClient<::Item, Error> - for XrpcWssClient<'_, P> +impl XrpcWssClient<::Item, Error> + for DefaultClient<'_, P> { - async fn connect(&self) -> Result::Item>, Error> { - let Self { xrpc_uri, params } = self; - let mut uri = xrpc_uri.to_uri(); - //// Query parameters - if let Some(p) = ¶ms { - uri.push('?'); - uri += &serde_html_form::to_string(p)?; - }; - //// + async fn connect(&self) -> Result::Item>, Error> { + let Self { xrpc_uri, params } = self; + let mut uri = xrpc_uri.to_uri(); + //// Query parameters + if let Some(p) = ¶ms { + uri.push('?'); + uri += &serde_html_form::to_string(p)?; + }; + //// - //// Request - // Extracting the authority from the URI to set the Host header. - let uri = Uri::from_str(&uri).map_err(|_| Error::InvalidUri)?; - let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); - let host = authority - .find('@') - .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); + //// Request + // Extracting the authority from the URI to set the Host header. + let uri = Uri::from_str(&uri).map_err(|_| Error::InvalidUri)?; + let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); + let host = authority + .find('@') + .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); - // Building the request. - let mut request = Request::builder() - .uri(&uri) - .method("GET") - .header("Host", host) - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .header("Sec-WebSocket-Version", "13") - .header("Sec-WebSocket-Key", generate_key()); + // Building the request. + let mut request = Request::builder() + .uri(&uri) + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()); - // Adding the ATProto headers. - if let Some(proxy) = self.atproto_proxy_header().await { - request = request.header(Header::AtprotoProxy, proxy); - } - if let Some(accept_labelers) = self.atproto_accept_labelers_header().await { - request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); - } + // Adding the ATProto headers. + if let Some(proxy) = self.atproto_proxy_header().await { + request = request.header(Header::AtprotoProxy, proxy); + } + if let Some(accept_labelers) = self.atproto_accept_labelers_header().await { + request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); + } - // In our case, the only thing that could possibly fail is the URI. The headers are all `String`/`&str`. - let request = request.body(()).map_err(|_| Error::InvalidUri)?; - //// + // In our case, the only thing that could possibly fail is the URI. The headers are all `String`/`&str`. 
+ let request = request.body(()).map_err(|_| Error::InvalidUri)?; + //// - let (stream, _) = connect_async(request).await?; - Ok(stream) - } + let (stream, _) = connect_async(request).await?; + Ok(stream) + } } diff --git a/atrium-xrpc-wss-client/src/lib.rs b/atrium-xrpc-wss-client/src/lib.rs index ed422338..b1159b29 100644 --- a/atrium-xrpc-wss-client/src/lib.rs +++ b/atrium-xrpc-wss-client/src/lib.rs @@ -1,6 +1,6 @@ mod client; -pub use client::{Error, XrpcWssClient}; +pub use client::{DefaultClient, Error}; pub mod subscriptions; -pub use atrium_xrpc_wss; // Re-export the atrium_xrpc_wss crate \ No newline at end of file +pub use atrium_xrpc_wss; // Re-export the atrium_xrpc_wss crate diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs index c7727f0f..d1bba6ab 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs @@ -5,181 +5,181 @@ use ipld_core::cid::Cid; use super::type_defs::{self, Operation}; use atrium_xrpc_wss::{ - atrium_api::{ - com::atproto::sync::subscribe_repos::{self, CommitData, InfoData, RepoOpData}, - record::KnownRecord, - types::Object, - }, - subscriptions::{ - handlers::repositories::{HandledData, Handler, ProcessedData}, - ConnectionHandler, ProcessedPayload, - } + atrium_api::{ + com::atproto::sync::subscribe_repos::{self, CommitData, InfoData, RepoOpData}, + record::KnownRecord, + types::Object, + }, + subscriptions::{ + handlers::repositories::{HandledData, Handler, ProcessedData}, + ConnectionHandler, ProcessedPayload, + }, }; /// Errors for this crate #[derive(Debug, thiserror::Error)] pub enum HandlingError { - #[error("CAR Decoding error: {0}")] - CarDecoding(#[from] rs_car::CarDecodeError), - #[error("IPLD Decoding error: {0}")] - IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), + #[error("CAR Decoding error: {0}")] + CarDecoding(#[from] rs_car::CarDecodeError), + #[error("IPLD Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), } pub struct Firehose; impl ConnectionHandler for Firehose { - type HandledData = HandledData; - type HandlingError = self::HandlingError; - - async fn handle_payload( - &self, - t: String, - payload: Vec, - ) -> Result>, Self::HandlingError> { - let res = match t.as_str() { - "#commit" => self - .process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Commit)), - "#identity" => self - .process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Identity)), - "#account" => self - .process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Account)), - "#handle" => self - .process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Handle)), - "#migrate" => self - .process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Migrate)), - "#tombstone" => self - .process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Tombstone)), - "#info" => self - .process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Info)), - _ => { - // "Clients should ignore frames with headers that have unknown op or t values. 
- // Unknown fields in both headers and payloads should be ignored." - // https://atproto.com/specs/event-stream - return Ok(None); - } - }; - - Ok(res) - } + type HandledData = HandledData; + type HandlingError = self::HandlingError; + + async fn handle_payload( + &self, + t: String, + payload: Vec, + ) -> Result>, Self::HandlingError> { + let res = match t.as_str() { + "#commit" => self + .process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Commit)), + "#identity" => self + .process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Identity)), + "#account" => self + .process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Account)), + "#handle" => self + .process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Handle)), + "#migrate" => self + .process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Migrate)), + "#tombstone" => self + .process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Tombstone)), + "#info" => self + .process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Info)), + _ => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + return Ok(None); + } + }; + + Ok(res) + } } impl Handler for Firehose { - type ProcessedCommitData = type_defs::ProcessedCommitData; - async fn process_commit( - &self, - payload: subscribe_repos::Commit, - ) -> Result>, Self::HandlingError> { - let CommitData { - blobs, - blocks, - commit, - ops, - repo, - rev, - seq, - since, - time, - too_big, - .. - } = payload.data; - - // If it is too big, the blocks and ops are not sent, so we skip the processing. - let ops_opt = if too_big { - None - } else { - // We read all the blocks from the CAR file and store them in a map - // so that we can look up the data for each operation by its CID. - let mut cursor = FutCursor::new(blocks); - let mut map = rs_car::car_read_all(&mut cursor, true) - .await? - .0 - .into_iter() - .map(compat_cid) - .collect::>(); - - // "Invalid framing or invalid DAG-CBOR encoding are hard errors, - // and the client should drop the entire connection instead of skipping the frame." - // https://atproto.com/specs/event-stream - Some(process_ops(ops, &mut map)?) 
- }; - - Ok(Some(ProcessedPayload { - seq: Some(seq), - data: Self::ProcessedCommitData { - ops: ops_opt, - blobs, - commit, - repo, - rev, - since, - time, - }, - })) - } - - type ProcessedIdentityData = type_defs::ProcessedIdentityData; - async fn process_identity( - &self, - _payload: subscribe_repos::Identity, - ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement - } - - type ProcessedAccountData = type_defs::ProcessedAccountData; - async fn process_account( - &self, - _payload: subscribe_repos::Account, - ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement - } - - type ProcessedHandleData = type_defs::ProcessedHandleData; - async fn process_handle( - &self, - _payload: subscribe_repos::Handle, - ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement - } - - type ProcessedMigrateData = type_defs::ProcessedMigrateData; - async fn process_migrate( - &self, - _payload: subscribe_repos::Migrate, - ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement - } - - type ProcessedTombstoneData = type_defs::ProcessedTombstoneData; - async fn process_tombstone( - &self, - _payload: subscribe_repos::Tombstone, - ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement - } - - type ProcessedInfoData = InfoData; - async fn process_info( - &self, - payload: subscribe_repos::Info, - ) -> Result>, Self::HandlingError> { - Ok(Some(ProcessedPayload { - seq: None, - data: payload.data, - })) - } + type ProcessedCommitData = type_defs::ProcessedCommitData; + async fn process_commit( + &self, + payload: subscribe_repos::Commit, + ) -> Result>, Self::HandlingError> { + let CommitData { + blobs, + blocks, + commit, + ops, + repo, + rev, + seq, + since, + time, + too_big, + .. + } = payload.data; + + // If it is too big, the blocks and ops are not sent, so we skip the processing. + let ops_opt = if too_big { + None + } else { + // We read all the blocks from the CAR file and store them in a map + // so that we can look up the data for each operation by its CID. + let mut cursor = FutCursor::new(blocks); + let mut map = rs_car::car_read_all(&mut cursor, true) + .await? + .0 + .into_iter() + .map(compat_cid) + .collect::>(); + + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + Some(process_ops(ops, &mut map)?) 
+ }; + + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedCommitData { + ops: ops_opt, + blobs, + commit, + repo, + rev, + since, + time, + }, + })) + } + + type ProcessedIdentityData = type_defs::ProcessedIdentityData; + async fn process_identity( + &self, + _payload: subscribe_repos::Identity, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedAccountData = type_defs::ProcessedAccountData; + async fn process_account( + &self, + _payload: subscribe_repos::Account, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedHandleData = type_defs::ProcessedHandleData; + async fn process_handle( + &self, + _payload: subscribe_repos::Handle, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedMigrateData = type_defs::ProcessedMigrateData; + async fn process_migrate( + &self, + _payload: subscribe_repos::Migrate, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedTombstoneData = type_defs::ProcessedTombstoneData; + async fn process_tombstone( + &self, + _payload: subscribe_repos::Tombstone, + ) -> Result>, Self::HandlingError> { + Ok(None) // TODO: Implement + } + + type ProcessedInfoData = InfoData; + async fn process_info( + &self, + payload: subscribe_repos::Info, + ) -> Result>, Self::HandlingError> { + Ok(Some(ProcessedPayload { + seq: None, + data: payload.data, + })) + } } // Transmute is here because the version of the `rs_car` crate for `cid` is 0.10.1 whereas @@ -187,39 +187,39 @@ impl Handler for Firehose { // memory layout was not changed between the two versions. Temporary fix. // TODO: Find a better way to fix the version compatibility issue. fn compat_cid((cid, item): (rs_car::Cid, Vec)) -> (ipld_core::cid::Cid, Vec) { - (unsafe { std::mem::transmute::<_, Cid>(cid) }, item) + (unsafe { std::mem::transmute::<_, Cid>(cid) }, item) } fn process_ops( - ops: Vec>, - map: &mut BTreeMap>, + ops: Vec>, + map: &mut BTreeMap>, ) -> Result, serde_ipld_dagcbor::DecodeError> { - let mut processed_ops = Vec::with_capacity(ops.len()); - for op in ops { - processed_ops.push(process_op(map, op)?); - } - Ok(processed_ops) + let mut processed_ops = Vec::with_capacity(ops.len()); + for op in ops { + processed_ops.push(process_op(map, op)?); + } + Ok(processed_ops) } /// Processes a single operation. fn process_op( - map: &mut BTreeMap>, - op: Object, + map: &mut BTreeMap>, + op: Object, ) -> Result> { - let RepoOpData { action, path, cid } = op.data; - - // Finds in the map the `Record` with the operation's CID and deserializes it. - // If the item is not found, returns `None`. - let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) { - Some(item) => Some(serde_ipld_dagcbor::from_reader::( - Cursor::new(item), - )?), - None => None, - }; - - Ok(Operation { - action, - path, - record, - }) + let RepoOpData { action, path, cid } = op.data; + + // Finds in the map the `Record` with the operation's CID and deserializes it. + // If the item is not found, returns `None`. 
+ let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) { + Some(item) => Some(serde_ipld_dagcbor::from_reader::( + Cursor::new(item), + )?), + None => None, + }; + + Ok(Operation { + action, + path, + record, + }) } diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs index cdc4110a..7ed7e68c 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs @@ -8,100 +8,105 @@ use bon::bon; use futures::{Stream, StreamExt}; use tokio_tungstenite::tungstenite::Message; -use atrium_xrpc_wss::{atrium_api::com::atproto::sync::subscribe_repos, subscriptions::{ - frames::{self, Frame}, - ConnectionHandler, ProcessedPayload, Subscription, SubscriptionError, -}}; +use atrium_xrpc_wss::{ + atrium_api::com::atproto::sync::subscribe_repos, + subscriptions::{ + frames::{self, Frame}, + ConnectionHandler, ProcessedPayload, Subscription, SubscriptionError, + }, +}; /// A struct that represents the repositories subscription, used in `com.atproto.sync.subscribeRepos`. pub struct Repositories { - /// This is only here to constrain the `ConnectionPayload` used in [`Subscription`], or else we get a compile error. - _payload_kind: PhantomData, + /// This is only here to constrain the `ConnectionPayload` used in [`Subscription`], or else we get a compile error. + _payload_kind: PhantomData, } #[bon] impl Repositories where - Self: Subscription, + Self: Subscription, { - #[builder] - /// Defines the builder for any generic `Repositories` struct that implements [`Subscription`]. - pub fn new( - connection: impl Stream + Unpin, - handler: H, - ) -> impl Stream, SubscriptionError>> - { - Self::handle_connection(connection, handler) - } + #[builder] + /// Defines the builder for any generic `Repositories` struct that implements [`Subscription`]. + pub fn new( + connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream< + Item = Result, SubscriptionError>, + > { + Self::handle_connection(connection, handler) + } } type WssResult = tokio_tungstenite::tungstenite::Result; impl Subscription for Repositories { - fn handle_connection( - mut connection: impl Stream + Unpin, - handler: H, - ) -> impl Stream, SubscriptionError>> - { - // Builds a new async stream that will deserialize the packets sent through the - // TCP tunnel and then yield the results processed by the handler back to the caller. - let stream = stream! { - loop { - match connection.next().await { - None => break, // Server dropped connection - Some(Err(e)) => { // WebSocket error - // "Invalid framing or invalid DAG-CBOR encoding are hard errors, - // and the client should drop the entire connection instead of skipping the frame." - // https://atproto.com/specs/event-stream - yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); - break; - } - Some(Ok(Message::Binary(data))) => { - match Frame::try_from(data) { - Ok(Frame::Message { t, data: payload }) => { - match handler.handle_payload(t, payload).await { - Ok(Some(res)) => yield Ok(res), // Payload was successfully handled. - Ok(None) => {}, // Payload was ignored by Handler. - Err(e) => { + fn handle_connection( + mut connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream< + Item = Result, SubscriptionError>, + > { + // Builds a new async stream that will deserialize the packets sent through the + // TCP tunnel and then yield the results processed by the handler back to the caller. 
+ let stream = stream! { + loop { + match connection.next().await { + None => break, // Server dropped connection + Some(Err(e)) => { // WebSocket error + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); + break; + } + Some(Ok(Message::Binary(data))) => { + match Frame::try_from(data) { + Ok(Frame::Message { t, data: payload }) => { + match handler.handle_payload(t, payload).await { + Ok(Some(res)) => yield Ok(res), // Payload was successfully handled. + Ok(None) => {}, // Payload was ignored by Handler. + Err(e) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid payload. Error: {e:?}"))); + break; + }, + } + }, + Ok(Frame::Error { data }) => { + yield match serde_ipld_dagcbor::from_reader::(data.as_slice()) { + Ok(e) => Err(SubscriptionError::Other(e)), + Err(e) => Err(SubscriptionError::Unknown(format!("Failed to decode error frame: {e:?}"))), + }; + break; + }, + Err(frames::Error::EmptyPayload(ipld)) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard frames::errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received empty payload for header: {ipld:?}"))); + break; + }, + Err(frames::Error::IpldDecoding(e)) => { // "Invalid framing or invalid DAG-CBOR encoding are hard errors, // and the client should drop the entire connection instead of skipping the frame." // https://atproto.com/specs/event-stream - yield Err(SubscriptionError::Abort(format!("Received invalid payload. Error: {e:?}"))); + yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); break; }, + Err(frames::Error::UnknownFrameType(_)) => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + }, } - }, - Ok(Frame::Error { data }) => { - yield match serde_ipld_dagcbor::from_reader::(data.as_slice()) { - Ok(e) => Err(SubscriptionError::Other(e)), - Err(e) => Err(SubscriptionError::Unknown(format!("Failed to decode error frame: {e:?}"))), - }; - break; - }, - Err(frames::Error::EmptyPayload(ipld)) => { - // "Invalid framing or invalid DAG-CBOR encoding are hard frames::errors, - // and the client should drop the entire connection instead of skipping the frame." - // https://atproto.com/specs/event-stream - yield Err(SubscriptionError::Abort(format!("Received empty payload for header: {ipld:?}"))); - break; - }, - Err(frames::Error::IpldDecoding(e)) => { - // "Invalid framing or invalid DAG-CBOR encoding are hard errors, - // and the client should drop the entire connection instead of skipping the frame." - // https://atproto.com/specs/event-stream - yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); - break; - }, - Err(frames::Error::UnknownFrameType(_)) => { - // "Clients should ignore frames with headers that have unknown op or t values. - // Unknown fields in both headers and payloads should be ignored." 
- // https://atproto.com/specs/event-stream - }, + } + _ => {}, // Ignore other message types. } } - _ => {}, // Ignore other message types. - } - } - }; + }; - Box::pin(stream) - } + Box::pin(stream) + } } diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs b/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs index 193fbd4c..31931b74 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs +++ b/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs @@ -1,30 +1,30 @@ //! This file defines the types used in the Firehose handler. use atrium_xrpc_wss::atrium_api::{ - record::KnownRecord, - types::{ - string::{Datetime, Did}, - CidLink, - }, + record::KnownRecord, + types::{ + string::{Datetime, Did}, + CidLink, + }, }; // region: Commit #[derive(Debug)] pub struct ProcessedCommitData { - pub repo: Did, - pub commit: CidLink, - // `ops` can be `None` if the commit is marked as `too_big`. - pub ops: Option>, - pub blobs: Vec, - pub rev: String, - pub since: Option, - pub time: Datetime, + pub repo: Did, + pub commit: CidLink, + // `ops` can be `None` if the commit is marked as `too_big`. + pub ops: Option>, + pub blobs: Vec, + pub rev: String, + pub since: Option, + pub time: Datetime, } #[derive(Debug)] pub struct Operation { - pub action: String, - pub path: String, - pub record: Option, + pub action: String, + pub path: String, + pub record: Option, } // endregion: Commit diff --git a/atrium-xrpc-wss/src/client/mod.rs b/atrium-xrpc-wss/src/client/mod.rs index 09ce804a..ae2c1d03 100644 --- a/atrium-xrpc-wss/src/client/mod.rs +++ b/atrium-xrpc-wss/src/client/mod.rs @@ -6,22 +6,22 @@ use futures::Stream; pub use xprc_uri::XrpcUri; /// An abstract WSS client. -pub trait WssClient { - /// Send an XRPC request. - /// - /// # Returns - /// [`Result`] - fn connect( - &self, - ) -> impl Future, ConnectionError>> + Send; +pub trait XrpcWssClient { + /// Send an XRPC request. + /// + /// # Returns + /// [`Result`] + fn connect( + &self, + ) -> impl Future, ConnectionError>> + Send; - /// Get the `atproto-proxy` header. - fn atproto_proxy_header(&self) -> impl Future> + Send { - async { None } - } + /// Get the `atproto-proxy` header. + fn atproto_proxy_header(&self) -> impl Future> + Send { + async { None } + } - /// Get the `atproto-accept-labelers` header. - fn atproto_accept_labelers_header(&self) -> impl Future>> + Send { - async { None } - } + /// Get the `atproto-accept-labelers` header. + fn atproto_accept_labelers_header(&self) -> impl Future>> + Send { + async { None } + } } diff --git a/atrium-xrpc-wss/src/client/xprc_uri.rs b/atrium-xrpc-wss/src/client/xprc_uri.rs index 519dc850..89d5a89f 100644 --- a/atrium-xrpc-wss/src/client/xprc_uri.rs +++ b/atrium-xrpc-wss/src/client/xprc_uri.rs @@ -1,16 +1,16 @@ /// The URI for the XRPC `WebSocket` connection. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct XrpcUri<'a> { - base_uri: &'a str, - nsid: &'a str, + base_uri: &'a str, + nsid: &'a str, } impl<'a> XrpcUri<'a> { - pub const fn new(base_uri: &'a str, nsid: &'a str) -> Self { - Self { base_uri, nsid } - } + pub const fn new(base_uri: &'a str, nsid: &'a str) -> Self { + Self { base_uri, nsid } + } - pub fn to_uri(&self) -> String { - let XrpcUri { base_uri, nsid } = self; - format!("wss://{base_uri}/xrpc/{nsid}") - } + pub fn to_uri(&self) -> String { + let XrpcUri { base_uri, nsid } = self; + format!("wss://{base_uri}/xrpc/{nsid}") + } } diff --git a/atrium-xrpc-wss/src/lib.rs b/atrium-xrpc-wss/src/lib.rs index c0b81c79..8e08c2a6 100644 --- a/atrium-xrpc-wss/src/lib.rs +++ b/atrium-xrpc-wss/src/lib.rs @@ -1,4 +1,4 @@ pub mod client; pub mod subscriptions; -pub use atrium_api; // Re-export the atrium_api crate \ No newline at end of file +pub use atrium_api; // Re-export the atrium_api crate diff --git a/atrium-xrpc-wss/src/subscriptions/frames/mod.rs b/atrium-xrpc-wss/src/subscriptions/frames/mod.rs index 0ec13090..e332721a 100644 --- a/atrium-xrpc-wss/src/subscriptions/frames/mod.rs +++ b/atrium-xrpc-wss/src/subscriptions/frames/mod.rs @@ -13,74 +13,69 @@ use std::io::Cursor; /// An error type for this crate. #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Unknown frame type. Header: {0:?}")] - UnknownFrameType(Ipld), - #[error("Payload was empty. Header: {0:?}")] - EmptyPayload(Ipld), - #[error("Ipld Decoding error: {0}")] - IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), + #[error("Unknown frame type. Header: {0:?}")] + UnknownFrameType(Ipld), + #[error("Payload was empty. Header: {0:?}")] + EmptyPayload(Ipld), + #[error("Ipld Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), } /// Represents the header of a frame. It's the first [`Ipld`] object in a Binary payload sent by a subscription. #[derive(Debug, Clone, PartialEq, Eq)] enum FrameHeader { - Message { t: String }, - Error, + Message { t: String }, + Error, } impl TryFrom for FrameHeader { - type Error = self::Error; + type Error = self::Error; - fn try_from(header: Ipld) -> Result>::Error> { - if let Ipld::Map(ref map) = header { - if let Some(Ipld::Integer(i)) = map.get("op") { - match i { - 1 => { - if let Some(Ipld::String(s)) = map.get("t") { - return Ok(Self::Message { t: s.to_owned() }); + fn try_from(header: Ipld) -> Result>::Error> { + if let Ipld::Map(ref map) = header { + if let Some(Ipld::Integer(i)) = map.get("op") { + match i { + 1 => { + if let Some(Ipld::String(s)) = map.get("t") { + return Ok(Self::Message { t: s.to_owned() }); + } + } + -1 => return Ok(Self::Error), + _ => {} + } } - } - -1 => return Ok(Self::Error), - _ => {} } - } + Err(Error::UnknownFrameType(header)) } - Err(Error::UnknownFrameType(header)) - } } /// Represents a frame sent by a subscription. It's the second [`Ipld`] object in a Binary payload sent by a subscription. 
#[derive(Debug, Clone, PartialEq, Eq)] pub enum Frame { - Message { - t: String, - data: Vec, - }, - Error { - data: Vec, - }, + Message { t: String, data: Vec }, + Error { data: Vec }, } impl TryFrom> for Frame { - type Error = self::Error; + type Error = self::Error; - fn try_from(value: Vec) -> Result>>::Error> { - let mut cursor = Cursor::new(value); - let mut deserializer = Deserializer::from_reader(IoReader::new(&mut cursor)); - let header = Deserialize::deserialize(&mut deserializer)?; + fn try_from(value: Vec) -> Result>>::Error> { + let mut cursor = Cursor::new(value); + let mut deserializer = Deserializer::from_reader(IoReader::new(&mut cursor)); + let header = Deserialize::deserialize(&mut deserializer)?; - // Error means the stream did not end (trailing data), which implies a second IPLD (in this case, the payload). - // If the stream ended, the payload is empty, in which case we error. - let data = if deserializer.end().is_err() { - let pos = cursor.position() as usize; - cursor.get_mut().drain(pos..).collect() - } else { - return Err(Error::EmptyPayload(header)); - }; + // Error means the stream did not end (trailing data), which implies a second IPLD (in this case, the payload). + // If the stream ended, the payload is empty, in which case we error. + let data = if deserializer.end().is_err() { + let pos = cursor.position() as usize; + cursor.get_mut().drain(pos..).collect() + } else { + return Err(Error::EmptyPayload(header)); + }; - match FrameHeader::try_from(header)? { - FrameHeader::Message { t } => Ok(Self::Message { t, data }), - FrameHeader::Error => Ok(Self::Error { data }), + match FrameHeader::try_from(header)? { + FrameHeader::Message { t } => Ok(Self::Message { t, data }), + FrameHeader::Error => Ok(Self::Error { data }), + } } - } } diff --git a/atrium-xrpc-wss/src/subscriptions/frames/tests.rs b/atrium-xrpc-wss/src/subscriptions/frames/tests.rs index 3f8a654d..bfab5a9d 100644 --- a/atrium-xrpc-wss/src/subscriptions/frames/tests.rs +++ b/atrium-xrpc-wss/src/subscriptions/frames/tests.rs @@ -1,61 +1,61 @@ use super::*; fn serialized_data(s: &str) -> Vec { - assert!(s.len() % 2 == 0); - let b2u = |b: u8| match b { - b'0'..=b'9' => b - b'0', - b'a'..=b'f' => b - b'a' + 10, - _ => unreachable!(), - }; - s.as_bytes() - .chunks(2) - .map(|b| (b2u(b[0]) << 4) + b2u(b[1])) - .collect() + assert!(s.len() % 2 == 0); + let b2u = |b: u8| match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => b - b'a' + 10, + _ => unreachable!(), + }; + s.as_bytes() + .chunks(2) + .map(|b| (b2u(b[0]) << 4) + b2u(b[1])) + .collect() } #[test] fn deserialize_message_frame_header() { - // {"op": 1, "t": "#commit"} - let data = serialized_data("a2626f700161746723636f6d6d6974"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect("failed to deserialize"), - FrameHeader::Message { - t: String::from("#commit") - } - ); + // {"op": 1, "t": "#commit"} + let data = serialized_data("a2626f700161746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect("failed to deserialize"), + FrameHeader::Message { + t: String::from("#commit") + } + ); } #[test] fn deserialize_error_frame_header() { - // {"op": -1} - let data = serialized_data("a1626f7020"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - 
assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error); + // {"op": -1} + let data = serialized_data("a1626f7020"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error); } #[test] fn deserialize_invalid_frame_header() { - { - // {"op": 2, "t": "#commit"} - let data = serialized_data("a2626f700261746723636f6d6d6974"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect_err("must be failed").to_string(), - "Unknown frame type. Header: {\"op\": 2, \"t\": \"#commit\"}" - ); - } - { - // {"op": -2} - let data = serialized_data("a1626f7021"); - let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); - let result = FrameHeader::try_from(ipld); - assert_eq!( - result.expect_err("must be failed").to_string(), - "Unknown frame type. Header: {\"op\": -2}" - ); - } + { + // {"op": 2, "t": "#commit"} + let data = serialized_data("a2626f700261746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. Header: {\"op\": 2, \"t\": \"#commit\"}" + ); + } + { + // {"op": -2} + let data = serialized_data("a1626f7021"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. Header: {\"op\": -2}" + ); + } } diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs b/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs index 3250664f..10de0374 100644 --- a/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs +++ b/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs @@ -1,3 +1,3 @@ use super::{ConnectionHandler, ProcessedPayload}; -pub mod repositories; \ No newline at end of file +pub mod repositories; diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs b/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs index 0e98cdfa..9a4b0198 100644 --- a/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs +++ b/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs @@ -9,25 +9,25 @@ use super::{ConnectionHandler, ProcessedPayload}; /// This type should be used to define [`ConnectionHandler::HandledData`](ConnectionHandler::HandledData) /// for the `com.atproto.sync.subscribeRepos` subscription type. pub type HandledData = ProcessedData< - ::ProcessedCommitData, - ::ProcessedIdentityData, - ::ProcessedAccountData, - ::ProcessedHandleData, - ::ProcessedMigrateData, - ::ProcessedTombstoneData, - ::ProcessedInfoData, + ::ProcessedCommitData, + ::ProcessedIdentityData, + ::ProcessedAccountData, + ::ProcessedHandleData, + ::ProcessedMigrateData, + ::ProcessedTombstoneData, + ::ProcessedInfoData, >; /// Wrapper around all the possible types of processed data. #[derive(Debug)] pub enum ProcessedData { - Commit(C), - Identity(I0), - Account(A), - Handle(H), - Migrate(M), - Tombstone(T), - Info(I1), + Commit(C), + Identity(I0), + Account(A), + Handle(H), + Migrate(M), + Tombstone(T), + Info(I1), } /// A trait that defines a [`ConnectionHandler`] specific to the @@ -40,86 +40,90 @@ pub enum ProcessedData { /// payload they pretend to use. 
The same goes for the implementations of /// each processing method, as the algorithm may vary. pub trait Handler: ConnectionHandler { - type ProcessedCommitData; - /// Processes a payload of type `#commit`. - fn process_commit( - &self, - payload: subscribe_repos::Commit, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedCommitData; + /// Processes a payload of type `#commit`. + fn process_commit( + &self, + payload: subscribe_repos::Commit, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } - type ProcessedIdentityData; - /// Processes a payload of type `#identity`. - fn process_identity( - &self, - payload: subscribe_repos::Identity, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedIdentityData; + /// Processes a payload of type `#identity`. + fn process_identity( + &self, + payload: subscribe_repos::Identity, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } - type ProcessedAccountData; - /// Processes a payload of type `#account`. - fn process_account( - &self, - payload: subscribe_repos::Account, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedAccountData; + /// Processes a payload of type `#account`. + fn process_account( + &self, + payload: subscribe_repos::Account, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } - type ProcessedHandleData; - /// Processes a payload of type `#handle`. - fn process_handle( - &self, - payload: subscribe_repos::Handle, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedHandleData; + /// Processes a payload of type `#handle`. + fn process_handle( + &self, + payload: subscribe_repos::Handle, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } - type ProcessedMigrateData; - /// Processes a payload of type `#migrate`. - fn process_migrate( - &self, - payload: subscribe_repos::Migrate, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedMigrateData; + /// Processes a payload of type `#migrate`. + fn process_migrate( + &self, + payload: subscribe_repos::Migrate, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. 
+ async { Ok(None) } + } - type ProcessedTombstoneData; - /// Processes a payload of type `#tombstone`. - fn process_tombstone( - &self, - payload: subscribe_repos::Tombstone, - ) -> impl Future< - Output = Result>, Self::HandlingError>, - > { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedTombstoneData; + /// Processes a payload of type `#tombstone`. + fn process_tombstone( + &self, + payload: subscribe_repos::Tombstone, + ) -> impl Future< + Output = Result< + Option>, + Self::HandlingError, + >, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } - type ProcessedInfoData; - /// Processes a payload of type `#info`. - fn process_info( - &self, - payload: subscribe_repos::Info, - ) -> impl Future>, Self::HandlingError>> - { - // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. - async { Ok(None) } - } + type ProcessedInfoData; + /// Processes a payload of type `#info`. + fn process_info( + &self, + payload: subscribe_repos::Info, + ) -> impl Future< + Output = Result>, Self::HandlingError>, + > { + // Default implementation always returns `None`, meaning the implementation decided to ignore the payload. + async { Ok(None) } + } } diff --git a/atrium-xrpc-wss/src/subscriptions/mod.rs b/atrium-xrpc-wss/src/subscriptions/mod.rs index a3a3db4b..67beca0a 100644 --- a/atrium-xrpc-wss/src/subscriptions/mod.rs +++ b/atrium-xrpc-wss/src/subscriptions/mod.rs @@ -7,25 +7,25 @@ use futures::Stream; /// A trait that defines the connection handler. pub trait ConnectionHandler { - /// The [`Self::HandledData`](ConnectionHandler::HandledData) type should be used to define the returned processed data type. - type HandledData; - /// The [`Self::HandlingError`](ConnectionHandler::HandlingError) type should be used to define the processing error type. - type HandlingError: 'static + Send + Sync + Debug; + /// The [`Self::HandledData`](ConnectionHandler::HandledData) type should be used to define the returned processed data type. + type HandledData; + /// The [`Self::HandlingError`](ConnectionHandler::HandlingError) type should be used to define the processing error type. + type HandlingError: 'static + Send + Sync + Debug; - /// Handles binary data coming from the connection. This function will deserialize the payload body and call the appropriate - /// handler for each payload type. - /// - /// # Returns - /// [`Result>`] like: - /// - `Ok(Some(processedPayload))` where `processedPayload` is [`ProcessedPayload`](ProcessedPayload) - /// if the payload was successfully processed. - /// - `Ok(None)` if the payload was ignored. - /// - `Err(e)` where `e` is [`ConnectionHandler::HandlingError`] if an error occurred while processing the payload. - fn handle_payload( - &self, - t: String, - payload: Vec, - ) -> impl Future>, Self::HandlingError>>; + /// Handles binary data coming from the connection. This function will deserialize the payload body and call the appropriate + /// handler for each payload type. + /// + /// # Returns + /// [`Result>`] like: + /// - `Ok(Some(processedPayload))` where `processedPayload` is [`ProcessedPayload`](ProcessedPayload) + /// if the payload was successfully processed. + /// - `Ok(None)` if the payload was ignored. + /// - `Err(e)` where `e` is [`ConnectionHandler::HandlingError`] if an error occurred while processing the payload. 
+ fn handle_payload( + &self, + t: String, + payload: Vec, + ) -> impl Future>, Self::HandlingError>>; } /// A trait that defines a subscription. @@ -33,31 +33,31 @@ pub trait ConnectionHandler { /// The `ConnectionPayload` type parameter is the type of the payload that will be received through the connection stream. /// The `Error` type parameter is the type of the error that the specific subscription can return, following the lexicon. pub trait Subscription { - /// The `handle_connection` method should be implemented to handle the connection. - /// - /// # Returns - /// A stream of processed payloads. - fn handle_connection( - connection: impl Stream + Unpin, - handler: H, - ) -> impl Stream, SubscriptionError>>; + /// The `handle_connection` method should be implemented to handle the connection. + /// + /// # Returns + /// A stream of processed payloads. + fn handle_connection( + connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream, SubscriptionError>>; } /// This struct represents a processed payload. /// It contains the sequence number (cursor) and the final processed data. pub struct ProcessedPayload { - pub seq: Option, // Might be absent, like in the case of #info. - pub data: Kind, + pub seq: Option, // Might be absent, like in the case of #info. + pub data: Kind, } /// Helper function to convert between payload kinds. impl ProcessedPayload { - pub fn map NewKind>(self, f: F) -> ProcessedPayload { - ProcessedPayload { - seq: self.seq, - data: f(self.data), + pub fn map NewKind>(self, f: F) -> ProcessedPayload { + ProcessedPayload { + seq: self.seq, + data: f(self.data), + } } - } } /// An error type that represents a subscription error. @@ -72,10 +72,10 @@ impl ProcessedPayload { /// This can be used to handle different kinds of errors, following the lexicon. #[derive(Debug, thiserror::Error)] pub enum SubscriptionError { - #[error("Critical Subscription Error: {0}")] - Abort(String), - #[error("Unknown Subscription Error: {0}")] - Unknown(String), - #[error(transparent)] - Other(T), + #[error("Critical Subscription Error: {0}")] + Abort(String), + #[error("Unknown Subscription Error: {0}")] + Unknown(String), + #[error(transparent)] + Other(T), } diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs index 8a5c8257..8a20e3a9 100644 --- a/examples/firehose/src/main.rs +++ b/examples/firehose/src/main.rs @@ -1,18 +1,18 @@ use anyhow::bail; use atrium_xrpc_wss_client::{ - atrium_xrpc_wss::{ - atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData}, - client::{WssClient, XrpcUri}, - subscriptions::{ - repositories::ProcessedData, - ProcessedPayload, SubscriptionError, + atrium_xrpc_wss::{ + atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData}, + client::{XrpcUri, XrpcWssClient}, + subscriptions::{ + handlers::repositories::ProcessedData, ProcessedPayload, SubscriptionError, + }, }, - }, - subscriptions::repositories::{ - firehose::Firehose, - type_defs::{Operation, ProcessedCommitData}, Repositories, - }, - Error, XrpcWssClient, + subscriptions::repositories::{ + firehose::Firehose, + type_defs::{Operation, ProcessedCommitData}, + Repositories, + }, + DefaultClient, Error, }; use futures::StreamExt; use tokio_tungstenite::tungstenite; @@ -20,127 +20,127 @@ use tokio_tungstenite::tungstenite; /// This example demonstrates how to connect to the ATProto Firehose. #[tokio::main] async fn main() { - // Define the XrpcUri for the subscription. 
- let xrpc_uri = XrpcUri::new("bsky.network", subscribe_repos::NSID); + // Define the XrpcUri for the subscription. + let xrpc_uri = XrpcUri::new("bsky.network", subscribe_repos::NSID); - // Caching the last cursor is important. - // The API has a backfilling mechanism that allows you to resume from where you stopped. - let mut last_cursor = None; - drop(connect(&mut last_cursor, &xrpc_uri).await); + // Caching the last cursor is important. + // The API has a backfilling mechanism that allows you to resume from where you stopped. + let mut last_cursor = None; + drop(connect(&mut last_cursor, &xrpc_uri).await); } /// Connects to `ATProto` to receive real-time data. async fn connect( - last_cursor: &mut Option, - xrpc_uri: &XrpcUri<'_>, + last_cursor: &mut Option, + xrpc_uri: &XrpcUri<'_>, ) -> Result<(), anyhow::Error> { - // Define the query parameters. In this case, just the cursor. - let params = subscribe_repos::ParametersData { - cursor: *last_cursor, - }; - - // Build a new XRPC WSS Client. - let client = XrpcWssClient::builder() - .xrpc_uri(xrpc_uri.clone()) - .params(params) - .build(); - - // And then we connect to the API. - let connection = match client.connect().await { - Ok(connection) => connection, - Err(Error::Connection(tungstenite::Error::Http(response))) => { - // According to the API documentation, the following status codes are expected and should be treated accordingly: - // 405 Method Not Allowed: Returned to client for non-GET HTTP requests to a stream endpoint. - // 426 Upgrade Required: Returned to client if Upgrade header is not included in a request to a stream endpoint. - // 429 Too Many Requests: Frequently used for rate-limiting. Client may try again after a delay. Support for the Retry-After header is encouraged. - // 500 Internal Server Error: Client may try again after a delay - // 501 Not Implemented: Service does not implement WebSockets or streams, at least for this endpoint. Client should not try again. - // 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout: Client may try again after a delay. - // https://atproto.com/specs/event-stream - bail!("Status Code was: {response:?}") - } - Err(e) => bail!(e), - }; + // Define the query parameters. In this case, just the cursor. + let params = subscribe_repos::ParametersData { + cursor: *last_cursor, + }; - // Builds a new subscription from the connection, using handler provided - // by atrium-xrpc-wss-client, the `Firehose`. - let mut subscription = Repositories::builder() - .connection(connection) - .handler(Firehose) - .build(); + // Build a new XRPC WSS Client. + let client = DefaultClient::builder() + .xrpc_uri(xrpc_uri.clone()) + .params(params) + .build(); - // Receive payloads by calling `StreamExt::next()`. - while let Some(payload) = subscription.next().await { - let data = match payload { - Ok(ProcessedPayload { seq, data }) => { - if let Some(seq) = seq { - *last_cursor = Some(seq); + // And then we connect to the API. + let connection = match client.connect().await { + Ok(connection) => connection, + Err(Error::Connection(tungstenite::Error::Http(response))) => { + // According to the API documentation, the following status codes are expected and should be treated accordingly: + // 405 Method Not Allowed: Returned to client for non-GET HTTP requests to a stream endpoint. + // 426 Upgrade Required: Returned to client if Upgrade header is not included in a request to a stream endpoint. + // 429 Too Many Requests: Frequently used for rate-limiting. Client may try again after a delay. 
Support for the Retry-After header is encouraged. + // 500 Internal Server Error: Client may try again after a delay + // 501 Not Implemented: Service does not implement WebSockets or streams, at least for this endpoint. Client should not try again. + // 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout: Client may try again after a delay. + // https://atproto.com/specs/event-stream + bail!("Status Code was: {response:?}") } - data - } - Err(SubscriptionError::Abort(reason)) => { - // This could mean multiple things, all of which are critical errors that require - // immediate termination of connection. - eprintln!("Aborted: {reason}"); - *last_cursor = None; - break; - } - Err(e) => { - // Errors such as `FutureCursor` and `ConsumerTooSlow` can be dealt with here. - eprintln!("{e:?}"); - *last_cursor = None; - break; - } + Err(e) => bail!(e), }; - match data { - ProcessedData::Commit(data) => beauty_print_commit(data), - ProcessedData::Info(InfoData { message, name }) => { - println!("Received info. Message: {message:?}; Name: {name}."); - } - _ => { /* Ignored */ } - }; - } + // Builds a new subscription from the connection, using handler provided + // by atrium-xrpc-wss-client, the `Firehose`. + let mut subscription = Repositories::builder() + .connection(connection) + .handler(Firehose) + .build(); + + // Receive payloads by calling `StreamExt::next()`. + while let Some(payload) = subscription.next().await { + let data = match payload { + Ok(ProcessedPayload { seq, data }) => { + if let Some(seq) = seq { + *last_cursor = Some(seq); + } + data + } + Err(SubscriptionError::Abort(reason)) => { + // This could mean multiple things, all of which are critical errors that require + // immediate termination of connection. + eprintln!("Aborted: {reason}"); + *last_cursor = None; + break; + } + Err(e) => { + // Errors such as `FutureCursor` and `ConsumerTooSlow` can be dealt with here. + eprintln!("{e:?}"); + *last_cursor = None; + break; + } + }; - Ok(()) + match data { + ProcessedData::Commit(data) => beauty_print_commit(data), + ProcessedData::Info(InfoData { message, name }) => { + println!("Received info. Message: {message:?}; Name: {name}."); + } + _ => { /* Ignored */ } + }; + } + + Ok(()) } fn beauty_print_commit(data: ProcessedCommitData) { - let ProcessedCommitData { - repo, commit, ops, .. - } = data; - if let Some(ops) = ops { - for r in ops { - let Operation { - action, - path, - record, - } = r; - let print = format!( - "\n\n\n################################# {} ##################################\n\ + let ProcessedCommitData { + repo, commit, ops, .. + } = data; + if let Some(ops) = ops { + for r in ops { + let Operation { + action, + path, + record, + } = r; + let print = format!( + "\n\n\n################################# {} ##################################\n\ - Repository (User DID): {}\n\ - Commit CID: {}\n\ - Path: {path}\n\ - Flagged as \"too big\"? ", - action.to_uppercase(), - repo.as_str(), - commit.0, - ); - // Record is only `None` when the commit was flagged as "too big". - if let Some(record) = record { - println!( - "{}No\n\ + action.to_uppercase(), + repo.as_str(), + commit.0, + ); + // Record is only `None` when the commit was flagged as "too big". 
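Stepping outside the example for a moment: the status-code notes above boil down to "reconnect after a delay unless the failure is permanent", and the cached cursor is what makes resumption cheap. Below is a rough supervision loop, a sketch only and not part of this patch. It assumes the example's `connect` helper, `XrpcUri`, and an `i64` cursor; the five-second delay is an arbitrary placeholder, and a production client would likely use exponential backoff and honor any Retry-After value from the HTTP error response.

```rust
// Sketch only. Assumes the `connect` helper and the imports from the example
// above are in scope.
async fn run_with_retries(xrpc_uri: &XrpcUri<'_>) {
    // The cursor survives across reconnects, so the backfilling mechanism
    // replays whatever was missed while the connection was down.
    let mut last_cursor: Option<i64> = None;
    loop {
        if let Err(e) = connect(&mut last_cursor, xrpc_uri).await {
            eprintln!("connection ended: {e}");
        }
        // Arbitrary fixed delay; tune per the status-code guidance above.
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
}
```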
+ if let Some(record) = record { + println!( + "{}No\n\ //-------------------------------- Record Info -------------------------------//\n\n\ {:?}", - print, record - ); - } else { - println!( - "{}Yes\n\ + print, record + ); + } else { + println!( + "{}Yes\n\ //---------------------------------------------------------------------------//\n\n", - print - ); - } + print + ); + } + } } - } } From 659e8190c8e5418679321955a091afdd61adfd59 Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Fri, 13 Sep 2024 00:32:48 -0300 Subject: [PATCH 3/7] Renaming crates and client + trait --- Cargo.lock | 64 +++++++++---------- Cargo.toml | 8 +-- README.md | 8 +-- .../.gitignore | 0 .../CHANGELOG.md | 0 .../Cargo.toml | 8 +-- .../README.md | 0 .../src/client.rs | 10 +-- atrium-streams-client/src/lib.rs | 6 ++ .../src/subscriptions/mod.rs | 0 .../subscriptions/repositories/firehose.rs | 2 +- .../src/subscriptions/repositories/mod.rs | 2 +- .../subscriptions/repositories/type_defs.rs | 2 +- .../.gitignore | 0 .../CHANGELOG.md | 0 .../Cargo.toml | 6 +- {atrium-xrpc-wss => atrium-streams}/README.md | 0 .../src/client/mod.rs | 2 +- .../src/client/xprc_uri.rs | 0 .../src/lib.rs | 0 .../src/subscriptions/frames/mod.rs | 0 .../src/subscriptions/frames/tests.rs | 0 .../src/subscriptions/handlers/mod.rs | 0 .../subscriptions/handlers/repositories.rs | 0 .../src/subscriptions/mod.rs | 0 atrium-xrpc-wss-client/src/lib.rs | 6 -- examples/firehose/Cargo.toml | 2 +- examples/firehose/src/main.rs | 12 ++-- 28 files changed, 69 insertions(+), 69 deletions(-) rename {atrium-xrpc-wss-client => atrium-streams-client}/.gitignore (100%) rename {atrium-xrpc-wss-client => atrium-streams-client}/CHANGELOG.md (100%) rename {atrium-xrpc-wss-client => atrium-streams-client}/Cargo.toml (74%) rename {atrium-xrpc-wss-client => atrium-streams-client}/README.md (100%) rename {atrium-xrpc-wss-client => atrium-streams-client}/src/client.rs (89%) create mode 100644 atrium-streams-client/src/lib.rs rename {atrium-xrpc-wss-client => atrium-streams-client}/src/subscriptions/mod.rs (100%) rename {atrium-xrpc-wss-client => atrium-streams-client}/src/subscriptions/repositories/firehose.rs (99%) rename {atrium-xrpc-wss-client => atrium-streams-client}/src/subscriptions/repositories/mod.rs (99%) rename {atrium-xrpc-wss-client => atrium-streams-client}/src/subscriptions/repositories/type_defs.rs (96%) rename {atrium-xrpc-wss => atrium-streams}/.gitignore (100%) rename {atrium-xrpc-wss => atrium-streams}/CHANGELOG.md (100%) rename {atrium-xrpc-wss => atrium-streams}/Cargo.toml (75%) rename {atrium-xrpc-wss => atrium-streams}/README.md (100%) rename {atrium-xrpc-wss => atrium-streams}/src/client/mod.rs (90%) rename {atrium-xrpc-wss => atrium-streams}/src/client/xprc_uri.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/lib.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/subscriptions/frames/mod.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/subscriptions/frames/tests.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/subscriptions/handlers/mod.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/subscriptions/handlers/repositories.rs (100%) rename {atrium-xrpc-wss => atrium-streams}/src/subscriptions/mod.rs (100%) delete mode 100644 atrium-xrpc-wss-client/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index b9d24778..cd61aa38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -200,6 +200,38 @@ dependencies = [ "thiserror", ] +[[package]] +name = "atrium-streams" +version = "0.1.0" 
+dependencies = [ + "atrium-api", + "cbor4ii", + "futures", + "ipld-core", + "serde", + "serde_ipld_dagcbor", + "thiserror", +] + +[[package]] +name = "atrium-streams-client" +version = "0.1.0" +dependencies = [ + "async-stream", + "atrium-streams", + "atrium-xrpc", + "bon", + "futures", + "ipld-core", + "rs-car", + "serde", + "serde_html_form", + "serde_ipld_dagcbor", + "thiserror", + "tokio", + "tokio-tungstenite", +] + [[package]] name = "atrium-xrpc" version = "0.11.3" @@ -231,38 +263,6 @@ dependencies = [ "wasm-bindgen-test", ] -[[package]] -name = "atrium-xrpc-wss" -version = "0.1.0" -dependencies = [ - "atrium-api", - "cbor4ii", - "futures", - "ipld-core", - "serde", - "serde_ipld_dagcbor", - "thiserror", -] - -[[package]] -name = "atrium-xrpc-wss-client" -version = "0.1.0" -dependencies = [ - "async-stream", - "atrium-xrpc", - "atrium-xrpc-wss", - "bon", - "futures", - "ipld-core", - "rs-car", - "serde", - "serde_html_form", - "serde_ipld_dagcbor", - "thiserror", - "tokio", - "tokio-tungstenite", -] - [[package]] name = "autocfg" version = "1.3.0" diff --git a/Cargo.toml b/Cargo.toml index 6f9c9698..6dfe6d71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,8 @@ members = [ "atrium-crypto", "atrium-xrpc", "atrium-xrpc-client", - "atrium-xrpc-wss", - "atrium-xrpc-wss-client", + "atrium-streams", + "atrium-streams-client", "bsky-cli", "bsky-sdk", ] @@ -28,8 +28,8 @@ keywords = ["atproto", "bluesky"] atrium-api = { version = "0.24.4", path = "atrium-api" } atrium-xrpc = { version = "0.11.3", path = "atrium-xrpc" } atrium-xrpc-client = { version = "0.5.6", path = "atrium-xrpc-client" } -atrium-xrpc-wss = { version = "0.1.0", path = "atrium-xrpc-wss" } -atrium-xrpc-wss-client = { version = "0.1.0", path = "atrium-xrpc-wss-client" } +atrium-streams = { version = "0.1.0", path = "atrium-streams" } +atrium-streams-client = { version = "0.1.0", path = "atrium-streams-client" } bsky-sdk = { version = "0.1.9", path = "bsky-sdk" } # async in traits diff --git a/README.md b/README.md index be5c4574..4f2d7da6 100644 --- a/README.md +++ b/README.md @@ -31,13 +31,13 @@ Definitions for XRPC request/response, and their associated errors. A library provides clients that implement the `XrpcClient` defined in [atrium-xrpc](./atrium-xrpc/) -### [`atrium-xrpc-wss`](./atrium-xrpc-wss/) +### [`atrium-streams`](./atrium-streams/) -Definitions for traits, types and utilities for dealing with WebSocket XRPC subscriptions. (WIP) +Definitions for traits, types and utilities for dealing with event stream subscriptions. 
(WIP) -### [`atrium-xrpc-wss-client`](./atrium-xrpc-wss-client/) +### [`atrium-streams-client`](./atrium-streams-client/) -A library that provides default implementations of the `XrpcWssClient`, `Handlers` and `Subscription` defined in [atrium-xrpc-wss](./atrium-xrpc-wss/) for interacting with the variety of subscriptions in ATProto (WIP) +A library that provides default implementations of the `EventStreamClient`, `Handlers` and `Subscription` defined in [atrium-streams](./atrium-streams/) for interacting with the variety of subscriptions in ATProto (WIP) ### [`bsky-sdk`](./bsky-sdk/) diff --git a/atrium-xrpc-wss-client/.gitignore b/atrium-streams-client/.gitignore similarity index 100% rename from atrium-xrpc-wss-client/.gitignore rename to atrium-streams-client/.gitignore diff --git a/atrium-xrpc-wss-client/CHANGELOG.md b/atrium-streams-client/CHANGELOG.md similarity index 100% rename from atrium-xrpc-wss-client/CHANGELOG.md rename to atrium-streams-client/CHANGELOG.md diff --git a/atrium-xrpc-wss-client/Cargo.toml b/atrium-streams-client/Cargo.toml similarity index 74% rename from atrium-xrpc-wss-client/Cargo.toml rename to atrium-streams-client/Cargo.toml index cdb6306b..3b498ab0 100644 --- a/atrium-xrpc-wss-client/Cargo.toml +++ b/atrium-streams-client/Cargo.toml @@ -1,11 +1,11 @@ [package] -name = "atrium-xrpc-wss-client" +name = "atrium-streams-client" version = "0.1.0" authors = ["Elaina <17bestradiol@proton.me>"] edition.workspace = true rust-version.workspace = true -description = "XRPC Websocket Client library for AT Protocol (Bluesky)" -documentation = "https://docs.rs/atrium-xrpc-wss-client" +description = "Event Streams Client library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-streams-client" readme = "README.md" repository.workspace = true license.workspace = true @@ -13,7 +13,7 @@ keywords.workspace = true [dependencies] atrium-xrpc.workspace = true -atrium-xrpc-wss.workspace = true +atrium-streams.workspace = true futures.workspace = true ipld-core.workspace = true async-stream.workspace = true diff --git a/atrium-xrpc-wss-client/README.md b/atrium-streams-client/README.md similarity index 100% rename from atrium-xrpc-wss-client/README.md rename to atrium-streams-client/README.md diff --git a/atrium-xrpc-wss-client/src/client.rs b/atrium-streams-client/src/client.rs similarity index 89% rename from atrium-xrpc-wss-client/src/client.rs rename to atrium-streams-client/src/client.rs index a6deb84e..67ea9021 100644 --- a/atrium-xrpc-wss-client/src/client.rs +++ b/atrium-streams-client/src/client.rs @@ -1,5 +1,5 @@ //! This file provides a client for the `ATProto` XRPC over WSS protocol. -//! It implements the [`WssClient`] trait for the [`XrpcWssClient`] struct. +//! It implements the [`EventStreamClient`] trait for the [`WssClient`] struct. use std::str::FromStr; @@ -18,7 +18,7 @@ use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, }; -use atrium_xrpc_wss::client::{XrpcUri, XrpcWssClient}; +use atrium_streams::client::{EventStreamClient, XrpcUri}; /// An enum of possible error kinds for this crate. #[derive(thiserror::Error, Debug)] @@ -32,14 +32,14 @@ pub enum Error { } #[derive(Builder)] -pub struct DefaultClient<'a, P: Serialize> { +pub struct WssClient<'a, P: Serialize> { xrpc_uri: XrpcUri<'a>, params: Option
<P>
, } type StreamKind = WebSocketStream>; -impl XrpcWssClient<::Item, Error> - for DefaultClient<'_, P> +impl EventStreamClient<::Item, Error> + for WssClient<'_, P> { async fn connect(&self) -> Result::Item>, Error> { let Self { xrpc_uri, params } = self; diff --git a/atrium-streams-client/src/lib.rs b/atrium-streams-client/src/lib.rs new file mode 100644 index 00000000..f0ee3e7b --- /dev/null +++ b/atrium-streams-client/src/lib.rs @@ -0,0 +1,6 @@ +mod client; +pub use client::{Error, WssClient}; + +pub mod subscriptions; + +pub use atrium_streams; // Re-export the atrium_streams crate diff --git a/atrium-xrpc-wss-client/src/subscriptions/mod.rs b/atrium-streams-client/src/subscriptions/mod.rs similarity index 100% rename from atrium-xrpc-wss-client/src/subscriptions/mod.rs rename to atrium-streams-client/src/subscriptions/mod.rs diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs b/atrium-streams-client/src/subscriptions/repositories/firehose.rs similarity index 99% rename from atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs rename to atrium-streams-client/src/subscriptions/repositories/firehose.rs index d1bba6ab..c52f4b4d 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/firehose.rs +++ b/atrium-streams-client/src/subscriptions/repositories/firehose.rs @@ -4,7 +4,7 @@ use futures::io::Cursor as FutCursor; use ipld_core::cid::Cid; use super::type_defs::{self, Operation}; -use atrium_xrpc_wss::{ +use atrium_streams::{ atrium_api::{ com::atproto::sync::subscribe_repos::{self, CommitData, InfoData, RepoOpData}, record::KnownRecord, diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs b/atrium-streams-client/src/subscriptions/repositories/mod.rs similarity index 99% rename from atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs rename to atrium-streams-client/src/subscriptions/repositories/mod.rs index 7ed7e68c..289e3898 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/mod.rs +++ b/atrium-streams-client/src/subscriptions/repositories/mod.rs @@ -8,7 +8,7 @@ use bon::bon; use futures::{Stream, StreamExt}; use tokio_tungstenite::tungstenite::Message; -use atrium_xrpc_wss::{ +use atrium_streams::{ atrium_api::com::atproto::sync::subscribe_repos, subscriptions::{ frames::{self, Frame}, diff --git a/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs similarity index 96% rename from atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs rename to atrium-streams-client/src/subscriptions/repositories/type_defs.rs index 31931b74..71408851 100644 --- a/atrium-xrpc-wss-client/src/subscriptions/repositories/type_defs.rs +++ b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs @@ -1,6 +1,6 @@ //! This file defines the types used in the Firehose handler. 
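One easy-to-miss detail in the `connect` implementation earlier in this patch is how the Host header is derived: the URI authority may carry userinfo (`user:pass@host`), and only the host part should be sent. A standalone sketch of that extraction follows; the helper name is made up for illustration.

```rust
/// Illustrative helper mirroring the authority handling in `connect` above:
/// drop any leading userinfo so only `host[:port]` goes into the Host header.
fn host_from_authority(authority: &str) -> &str {
    authority.find('@').map_or(authority, |idx| &authority[idx + 1..])
}

fn main() {
    assert_eq!(host_from_authority("bsky.network"), "bsky.network");
    assert_eq!(host_from_authority("user:pass@bsky.network:443"), "bsky.network:443");
}
```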
-use atrium_xrpc_wss::atrium_api::{ +use atrium_streams::atrium_api::{ record::KnownRecord, types::{ string::{Datetime, Did}, diff --git a/atrium-xrpc-wss/.gitignore b/atrium-streams/.gitignore similarity index 100% rename from atrium-xrpc-wss/.gitignore rename to atrium-streams/.gitignore diff --git a/atrium-xrpc-wss/CHANGELOG.md b/atrium-streams/CHANGELOG.md similarity index 100% rename from atrium-xrpc-wss/CHANGELOG.md rename to atrium-streams/CHANGELOG.md diff --git a/atrium-xrpc-wss/Cargo.toml b/atrium-streams/Cargo.toml similarity index 75% rename from atrium-xrpc-wss/Cargo.toml rename to atrium-streams/Cargo.toml index d3d760fb..d1f7453d 100644 --- a/atrium-xrpc-wss/Cargo.toml +++ b/atrium-streams/Cargo.toml @@ -1,11 +1,11 @@ [package] -name = "atrium-xrpc-wss" +name = "atrium-streams" version = "0.1.0" authors = ["Elaina <17bestradiol@proton.me>"] edition.workspace = true rust-version.workspace = true -description = "XRPC Websocket library for AT Protocol (Bluesky)" -documentation = "https://docs.rs/atrium-xrpc-wss" +description = "Event Streams library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-streams" readme = "README.md" repository.workspace = true license.workspace = true diff --git a/atrium-xrpc-wss/README.md b/atrium-streams/README.md similarity index 100% rename from atrium-xrpc-wss/README.md rename to atrium-streams/README.md diff --git a/atrium-xrpc-wss/src/client/mod.rs b/atrium-streams/src/client/mod.rs similarity index 90% rename from atrium-xrpc-wss/src/client/mod.rs rename to atrium-streams/src/client/mod.rs index ae2c1d03..3248a3d7 100644 --- a/atrium-xrpc-wss/src/client/mod.rs +++ b/atrium-streams/src/client/mod.rs @@ -6,7 +6,7 @@ use futures::Stream; pub use xprc_uri::XrpcUri; /// An abstract WSS client. -pub trait XrpcWssClient { +pub trait EventStreamClient { /// Send an XRPC request. 
/// /// # Returns diff --git a/atrium-xrpc-wss/src/client/xprc_uri.rs b/atrium-streams/src/client/xprc_uri.rs similarity index 100% rename from atrium-xrpc-wss/src/client/xprc_uri.rs rename to atrium-streams/src/client/xprc_uri.rs diff --git a/atrium-xrpc-wss/src/lib.rs b/atrium-streams/src/lib.rs similarity index 100% rename from atrium-xrpc-wss/src/lib.rs rename to atrium-streams/src/lib.rs diff --git a/atrium-xrpc-wss/src/subscriptions/frames/mod.rs b/atrium-streams/src/subscriptions/frames/mod.rs similarity index 100% rename from atrium-xrpc-wss/src/subscriptions/frames/mod.rs rename to atrium-streams/src/subscriptions/frames/mod.rs diff --git a/atrium-xrpc-wss/src/subscriptions/frames/tests.rs b/atrium-streams/src/subscriptions/frames/tests.rs similarity index 100% rename from atrium-xrpc-wss/src/subscriptions/frames/tests.rs rename to atrium-streams/src/subscriptions/frames/tests.rs diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/mod.rs b/atrium-streams/src/subscriptions/handlers/mod.rs similarity index 100% rename from atrium-xrpc-wss/src/subscriptions/handlers/mod.rs rename to atrium-streams/src/subscriptions/handlers/mod.rs diff --git a/atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs b/atrium-streams/src/subscriptions/handlers/repositories.rs similarity index 100% rename from atrium-xrpc-wss/src/subscriptions/handlers/repositories.rs rename to atrium-streams/src/subscriptions/handlers/repositories.rs diff --git a/atrium-xrpc-wss/src/subscriptions/mod.rs b/atrium-streams/src/subscriptions/mod.rs similarity index 100% rename from atrium-xrpc-wss/src/subscriptions/mod.rs rename to atrium-streams/src/subscriptions/mod.rs diff --git a/atrium-xrpc-wss-client/src/lib.rs b/atrium-xrpc-wss-client/src/lib.rs deleted file mode 100644 index b1159b29..00000000 --- a/atrium-xrpc-wss-client/src/lib.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod client; -pub use client::{DefaultClient, Error}; - -pub mod subscriptions; - -pub use atrium_xrpc_wss; // Re-export the atrium_xrpc_wss crate diff --git a/examples/firehose/Cargo.toml b/examples/firehose/Cargo.toml index 4c2cf0d8..d9fceaca 100644 --- a/examples/firehose/Cargo.toml +++ b/examples/firehose/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] anyhow = "1.0.86" -atrium-xrpc-wss-client = { path = "../../atrium-xrpc-wss-client" } +atrium-streams-client = { path = "../../atrium-streams-client" } futures = "0.3.30" tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } tokio = { version = "1.36.0", features = ["full"] } \ No newline at end of file diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs index 8a20e3a9..51591d9e 100644 --- a/examples/firehose/src/main.rs +++ b/examples/firehose/src/main.rs @@ -1,8 +1,8 @@ use anyhow::bail; -use atrium_xrpc_wss_client::{ - atrium_xrpc_wss::{ +use atrium_streams_client::{ + atrium_streams::{ atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData}, - client::{XrpcUri, XrpcWssClient}, + client::{XrpcUri, EventStreamClient}, subscriptions::{ handlers::repositories::ProcessedData, ProcessedPayload, SubscriptionError, }, @@ -12,7 +12,7 @@ use atrium_xrpc_wss_client::{ type_defs::{Operation, ProcessedCommitData}, Repositories, }, - DefaultClient, Error, + WssClient, Error, }; use futures::StreamExt; use tokio_tungstenite::tungstenite; @@ -40,7 +40,7 @@ async fn connect( }; // Build a new XRPC WSS Client. 
- let client = DefaultClient::builder() + let client = WssClient::builder() .xrpc_uri(xrpc_uri.clone()) .params(params) .build(); @@ -63,7 +63,7 @@ async fn connect( }; // Builds a new subscription from the connection, using handler provided - // by atrium-xrpc-wss-client, the `Firehose`. + // by atrium-streams-client, the `Firehose`. let mut subscription = Repositories::builder() .connection(connection) .handler(Firehose) From 575a6c84d92d77d413cd1df60112575059af2dce Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Fri, 13 Sep 2024 16:47:36 -0300 Subject: [PATCH 4/7] Implementing what was missing - Firehose --- .../subscriptions/repositories/firehose.rs | 186 ++++++++++++++---- .../subscriptions/repositories/type_defs.rs | 32 ++- examples/firehose/src/main.rs | 11 +- 3 files changed, 181 insertions(+), 48 deletions(-) diff --git a/atrium-streams-client/src/subscriptions/repositories/firehose.rs b/atrium-streams-client/src/subscriptions/repositories/firehose.rs index c52f4b4d..e061f617 100644 --- a/atrium-streams-client/src/subscriptions/repositories/firehose.rs +++ b/atrium-streams-client/src/subscriptions/repositories/firehose.rs @@ -6,7 +6,10 @@ use ipld_core::cid::Cid; use super::type_defs::{self, Operation}; use atrium_streams::{ atrium_api::{ - com::atproto::sync::subscribe_repos::{self, CommitData, InfoData, RepoOpData}, + com::atproto::sync::subscribe_repos::{ + self, AccountData, CommitData, HandleData, IdentityData, InfoData, MigrateData, + RepoOpData, TombstoneData, + }, record::KnownRecord, types::Object, }, @@ -25,7 +28,23 @@ pub enum HandlingError { IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), } -pub struct Firehose; +#[derive(bon::Builder)] +pub struct Firehose { + #[builder(default)] + enable_commit: bool, + #[builder(default)] + enable_identity: bool, + #[builder(default)] + enable_account: bool, + #[builder(default)] + enable_handle: bool, + #[builder(default)] + enable_migrate: bool, + #[builder(default)] + enable_tombstone: bool, + #[builder(default)] + enable_info: bool, +} impl ConnectionHandler for Firehose { type HandledData = HandledData; type HandlingError = self::HandlingError; @@ -36,39 +55,74 @@ impl ConnectionHandler for Firehose { payload: Vec, ) -> Result>, Self::HandlingError> { let res = match t.as_str() { - "#commit" => self - .process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Commit)), - "#identity" => self - .process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Identity)), - "#account" => self - .process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Account)), - "#handle" => self - .process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Handle)), - "#migrate" => self - .process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Migrate)), - "#tombstone" => self - .process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Tombstone)), - "#info" => self - .process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) - .await? - .map(|data| data.map(ProcessedData::Info)), + "#commit" => { + if self.enable_commit { + self.process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? 
+ .map(|data| data.map(ProcessedData::Commit)) + } else { + None + } + } + "#identity" => { + if self.enable_identity { + self.process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Identity)) + } else { + None + } + } + "#account" => { + if self.enable_account { + self.process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Account)) + } else { + None + } + } + "#handle" => { + if self.enable_handle { + self.process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Handle)) + } else { + None + } + } + "#migrate" => { + if self.enable_migrate { + self.process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Migrate)) + } else { + None + } + } + "#tombstone" => { + if self.enable_tombstone { + self.process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Tombstone)) + } else { + None + } + } + "#info" => { + if self.enable_info { + self.process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Info)) + } else { + None + } + } _ => { // "Clients should ignore frames with headers that have unknown op or t values. // Unknown fields in both headers and payloads should be ignored." // https://atproto.com/specs/event-stream - return Ok(None); + None } }; @@ -133,41 +187,91 @@ impl Handler for Firehose { type ProcessedIdentityData = type_defs::ProcessedIdentityData; async fn process_identity( &self, - _payload: subscribe_repos::Identity, + payload: subscribe_repos::Identity, ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement + let IdentityData { + did, + handle, + seq, + time, + } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedIdentityData { did, handle, time }, + })) } type ProcessedAccountData = type_defs::ProcessedAccountData; async fn process_account( &self, - _payload: subscribe_repos::Account, + payload: subscribe_repos::Account, ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement + let AccountData { + did, + seq, + time, + active, + status, + } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedAccountData { + did, + active, + status, + time, + }, + })) } type ProcessedHandleData = type_defs::ProcessedHandleData; async fn process_handle( &self, - _payload: subscribe_repos::Handle, + payload: subscribe_repos::Handle, ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement + let HandleData { + did, + handle, + seq, + time, + } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedHandleData { did, handle, time }, + })) } type ProcessedMigrateData = type_defs::ProcessedMigrateData; async fn process_migrate( &self, - _payload: subscribe_repos::Migrate, + payload: subscribe_repos::Migrate, ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: Implement + let MigrateData { + did, + migrate_to, + seq, + time, + } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedMigrateData { + did, + migrate_to, + time, + }, + })) } type ProcessedTombstoneData = type_defs::ProcessedTombstoneData; async fn process_tombstone( &self, - _payload: subscribe_repos::Tombstone, + payload: subscribe_repos::Tombstone, ) -> Result>, Self::HandlingError> { - Ok(None) // TODO: 
Implement + let TombstoneData { did, seq, time } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedTombstoneData { did, time }, + })) } type ProcessedInfoData = InfoData; diff --git a/atrium-streams-client/src/subscriptions/repositories/type_defs.rs b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs index 71408851..ce9b272c 100644 --- a/atrium-streams-client/src/subscriptions/repositories/type_defs.rs +++ b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs @@ -3,7 +3,7 @@ use atrium_streams::atrium_api::{ record::KnownRecord, types::{ - string::{Datetime, Did}, + string::{Datetime, Did, Handle}, CidLink, }, }; @@ -30,25 +30,45 @@ pub struct Operation { // region: Identity #[derive(Debug)] -pub struct ProcessedIdentityData {} +pub struct ProcessedIdentityData { + pub did: Did, + pub handle: Option, + pub time: Datetime, +} // endregion: Identity // region: Account #[derive(Debug)] -pub struct ProcessedAccountData {} +pub struct ProcessedAccountData { + pub did: Did, + pub active: bool, + pub status: Option, + pub time: Datetime, +} // endregion: Account // region: Handle #[derive(Debug)] -pub struct ProcessedHandleData {} +pub struct ProcessedHandleData { + pub did: Did, + pub handle: Handle, + pub time: Datetime, +} // endregion: Handle // region: Migrate #[derive(Debug)] -pub struct ProcessedMigrateData {} +pub struct ProcessedMigrateData { + pub did: Did, + pub migrate_to: Option, + pub time: Datetime, +} // endregion: Migrate // region: Tombstone #[derive(Debug)] -pub struct ProcessedTombstoneData {} +pub struct ProcessedTombstoneData { + pub did: Did, + pub time: Datetime, +} // endregion: Tombstone diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs index 51591d9e..8334f01c 100644 --- a/examples/firehose/src/main.rs +++ b/examples/firehose/src/main.rs @@ -62,11 +62,20 @@ async fn connect( Err(e) => bail!(e), }; + // Builds the subscription handler + let firehose = Firehose::builder() + // You can enable or disable specific events, and every event is disabled by default. + // That way they don't get unnecessarily processed and you save up resources. + // Enable only the ones you plan to use. + .enable_commit(true) + .enable_info(true) + .build(); + // Builds a new subscription from the connection, using handler provided // by atrium-streams-client, the `Firehose`. let mut subscription = Repositories::builder() .connection(connection) - .handler(Firehose) + .handler(firehose) .build(); // Receive payloads by calling `StreamExt::next()`. 
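Since the `Firehose` handler is now opt-in per event type, a consumer that only cares about a subset can skip the rest before any payload decoding happens. A hedged usage sketch, assuming a `connection` obtained as in the example above, for a service that only needs identity and account events:

```rust
// Sketch: enable only the events this consumer cares about; everything else is
// dropped before its DAG-CBOR payload is decoded, so it costs very little.
let firehose = Firehose::builder()
    .enable_identity(true)
    .enable_account(true)
    .build();

let mut subscription = Repositories::builder()
    .connection(connection) // e.g. the stream returned by the client's `connect`
    .handler(firehose)
    .build();
```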
From d743eab5ab34a6e651cdfd94f9a3bab04a4c6c43 Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Fri, 13 Sep 2024 22:22:14 -0300 Subject: [PATCH 5/7] Adding tests for subscribeRepos firehose implementation --- Cargo.lock | 2 + atrium-streams-client/Cargo.toml | 6 +- .../src/subscriptions/repositories/mod.rs | 5 +- .../src/subscriptions/repositories/tests.rs | 491 ++++++++++++++++++ 4 files changed, 502 insertions(+), 2 deletions(-) create mode 100644 atrium-streams-client/src/subscriptions/repositories/tests.rs diff --git a/Cargo.lock b/Cargo.lock index cd61aa38..8a7903c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -217,6 +217,7 @@ dependencies = [ name = "atrium-streams-client" version = "0.1.0" dependencies = [ + "anyhow", "async-stream", "atrium-streams", "atrium-xrpc", @@ -227,6 +228,7 @@ dependencies = [ "serde", "serde_html_form", "serde_ipld_dagcbor", + "serde_json", "thiserror", "tokio", "tokio-tungstenite", diff --git a/atrium-streams-client/Cargo.toml b/atrium-streams-client/Cargo.toml index 3b498ab0..af916392 100644 --- a/atrium-streams-client/Cargo.toml +++ b/atrium-streams-client/Cargo.toml @@ -24,4 +24,8 @@ tokio.workspace = true bon.workspace = true serde_html_form.workspace = true serde.workspace = true -thiserror.workspace = true \ No newline at end of file +thiserror.workspace = true + +[dev-dependencies] +anyhow.workspace = true +serde_json.workspace = true \ No newline at end of file diff --git a/atrium-streams-client/src/subscriptions/repositories/mod.rs b/atrium-streams-client/src/subscriptions/repositories/mod.rs index 289e3898..48e9f2ad 100644 --- a/atrium-streams-client/src/subscriptions/repositories/mod.rs +++ b/atrium-streams-client/src/subscriptions/repositories/mod.rs @@ -1,6 +1,9 @@ pub mod firehose; pub mod type_defs; +#[cfg(test)] +mod tests; + use std::marker::PhantomData; use async_stream::stream; @@ -56,7 +59,7 @@ impl Subscription for Repositories // "Invalid framing or invalid DAG-CBOR encoding are hard errors, // and the client should drop the entire connection instead of skipping the frame." // https://atproto.com/specs/event-stream - yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); + yield Err(SubscriptionError::Abort(format!("Received invalid packet. 
Error: {e:?}"))); break; } Some(Ok(Message::Binary(data))) => { diff --git a/atrium-streams-client/src/subscriptions/repositories/tests.rs b/atrium-streams-client/src/subscriptions/repositories/tests.rs new file mode 100644 index 00000000..c68059ed --- /dev/null +++ b/atrium-streams-client/src/subscriptions/repositories/tests.rs @@ -0,0 +1,491 @@ +use std::{convert::identity as id, vec}; + +use atrium_streams::{ + atrium_api::{ + com::atproto::{ + label::subscribe_labels::InfoData, + sync::subscribe_repos::{ + self, AccountData, CommitData, HandleData, IdentityData, MigrateData, TombstoneData, + }, + }, + types::{ + string::{Datetime, Did, Handle}, + CidLink, Object, + }, + }, + subscriptions::{ + handlers::repositories::HandledData, ConnectionHandler, ProcessedPayload, SubscriptionError, + }, +}; +use futures::{executor::block_on_stream, Stream}; +use ipld_core::{ + cid::{multihash::Multihash, Cid}, + ipld::Ipld, +}; +use serde_json::Value; +use tokio_tungstenite::tungstenite::{Error, Message}; + +use super::{firehose::Firehose, Repositories}; + +fn serialize_ipld(frame: &str) -> Result, anyhow::Error> { + if frame.is_empty() { + return Ok(vec![]); + } + + let json: Value = serde_json::from_str(frame)?; + let bytes = serde_ipld_dagcbor::to_vec(&json)?; + Ok(bytes) +} + +fn mock_connection<'a>( + packets: Vec<(&'a str, &'a str)>, +) -> impl Stream> + Unpin + 'a { + let mut stream = packets.into_iter().map(|(header, payload)| { + // Using Utf8 as an arbitrary tungstenite error + serialize_ipld(header) + .map(|mut v| { + serialize_ipld(payload) + .map(|mut p| { + Message::Binary({ + v.append(&mut p); + v + }) + }) + .map_err(|_| Error::Utf8) + }) + .map_err(|_| Error::Utf8) + .and_then(id) + }); + let connection = async_stream::stream! { + while let Some(packet) = stream.next() { + yield packet; + } + }; + + Box::pin(connection) +} + +fn test_packet( + packet: Option<(&str, &str)>, +) -> Option, HandledData), SubscriptionError>> +{ + let connection = mock_connection(if let Some(packet) = packet { + vec![packet] + } else { + vec![] + }); + + let subscription = gen_default_subscription(connection); + + block_on_stream(subscription) + .next() + .map(|v| v.map(|ProcessedPayload { data, seq }| (seq, data))) +} + +fn gen_default_subscription( + connection: impl Stream> + Unpin, +) -> impl Stream< + Item = Result< + ProcessedPayload<::HandledData>, + SubscriptionError, + >, +> { + let firehose = Firehose::builder() + .enable_commit(true) + .enable_identity(true) + .enable_account(true) + .enable_handle(true) + .enable_migrate(true) + .enable_tombstone(true) + .enable_info(true) + .build(); + let subscription = Repositories::builder() + .connection(connection) + .handler(firehose) + .build(); + subscription +} + +#[test] +fn disconnect() { + if test_packet(None).is_none() { + return; + } + panic!("Expected None") +} + +#[test] +fn invalid_packet() { + if let SubscriptionError::Abort(msg) = + test_packet(Some(("{ not-a-header }", "{ not-a-payload }"))) + .unwrap() + .unwrap_err() + { + assert_eq!(msg, "Received invalid packet. 
Error: Utf8"); + return; + } + panic!("Expected Invalid Packet") +} + +#[test] +fn commit() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: Some(CommitData { + blobs: vec![], + blocks: vec![], + commit: CidLink(Cid::new_v1( + 0x70, + Multihash::<64>::wrap(0x12, &[0; 64]).unwrap(), + )), + ops: vec![], + prev: None, + rebase: false, + repo: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + rev: String::new(), + seq: 99, + since: None, + time: now, + too_big: true, + }), + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Commit(ProcessedCommitData {{ \ + repo: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + commit: CidLink(Cid(bafybeqaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa)), \ + ops: None, \ + blobs: [], \ + rev: \"\", \ + since: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn identity() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: IdentityData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + handle: None, + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#identity" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Identity(ProcessedIdentityData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + handle: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn account() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: AccountData { + active: false, + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + seq: 99, + status: None, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#account" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Account(ProcessedAccountData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + active: false, \ + status: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn handle() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: HandleData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + handle: Handle::new("test.handle.xyz".to_string()).unwrap(), + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#handle" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Handle(ProcessedHandleData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + handle: Handle(\"test.handle.xyz\"), \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn migrate() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: MigrateData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + migrate_to: None, + seq: 99, + time: now, + }, + extra_data: 
Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#migrate" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Migrate(ProcessedMigrateData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + migrate_to: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn tombstone() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: TombstoneData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#tombstone" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Tombstone(ProcessedTombstoneData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn info() { + let body = Object { + data: InfoData { + message: Some("Requested cursor exceeded limit. Possibly missing events".to_string()), + name: "OutdatedCursor".to_string(), + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#info" }"##, &body))) + .unwrap() + .unwrap(); + assert_eq!(seq, None); + assert_eq!( + format!("{:?}", data), + "Info(InfoData { \ + message: Some(\"Requested cursor exceeded limit. Possibly missing events\"), \ + name: \"OutdatedCursor\" \ + })" + .to_string() + ); +} + +#[test] +fn ignored_frame() { + if test_packet(Some(( + r##"{ "op": 1, "t": "#non-existent" }"##, + r#"{ "foo": "bar" }"#, + ))) + .is_none() + { + return; + } + panic!("Expected None") +} + +#[test] +fn invalid_body() { + let body = Object { + data: Some(CommitData { + blobs: vec![], + blocks: vec![1], // Invalid CAR file + commit: CidLink(Cid::new_v1( + 0x70, + Multihash::<64>::wrap(0x12, &[0; 64]).unwrap(), + )), + ops: vec![], + prev: None, + rebase: false, + repo: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + rev: String::new(), + seq: 0, + since: None, + time: Datetime::now(), + too_big: false, + }), + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + if let SubscriptionError::Abort(msg) = + test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))) + .unwrap() + .unwrap_err() + { + assert_eq!( + msg, + "Received invalid payload. Error: CarDecoding(IoError(Kind(UnexpectedEof)))" + ); + return; + } +} + +#[test] +fn future_cursor() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "FutureCursor", "message": "Cursor in the future." 
}"#, + ))); + if let SubscriptionError::Other(subscribe_repos::Error::FutureCursor(Some(s))) = + res.unwrap().unwrap_err() + { + assert_eq!("Cursor in the future.", &s); + return; + } + panic!("Expected FutureCursor") +} + +#[test] +fn consumer_too_slow() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "ConsumerTooSlow", "message": "Stream consumer too slow" }"#, + ))); + if let SubscriptionError::Other(subscribe_repos::Error::ConsumerTooSlow(Some(s))) = + res.unwrap().unwrap_err() + { + assert_eq!("Stream consumer too slow", &s); + return; + } + panic!("Expected ConsumerTooSlow") +} + +#[test] +fn unknown_error() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "Unknown", "message": "No one knows" }"#, + ))); + if let SubscriptionError::Unknown(msg) = res.unwrap().unwrap_err() { + assert_eq!( + "Failed to decode error frame: \ + Msg(\"unknown variant `Unknown`, expected `FutureCursor` or `ConsumerTooSlow`\")", + &msg + ); + return; + } + panic!("Expected Unknown") +} + +#[test] +fn empty_payload() { + let res = test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, r#""#))); + if let SubscriptionError::Abort(msg) = res.unwrap().unwrap_err() { + assert_eq!( + "Received empty payload for header: {\"op\": 1, \"t\": \"#commit\"}", + &msg + ); + return; + } + panic!("Expected Empty Payload") +} + +#[test] +fn invalid_frame() { + fn mock_invalid() -> impl Stream> + Unpin { + let mut stream = vec![Message::Binary(vec![b'{'])].into_iter(); + let connection = async_stream::stream! { + while let Some(packet) = stream.next() { + yield Ok(packet); + } + }; + Box::pin(connection) + } + + let subscription = gen_default_subscription(mock_invalid()); + + let res = block_on_stream(subscription) + .next() + .map(|v| v.map(|ProcessedPayload { data, seq }| (seq, data))); + + if let SubscriptionError::Abort(msg) = res.unwrap().unwrap_err() { + assert_eq!("Received invalid frame. 
Error: Eof", &msg); + return; + } + panic!("Expected Invalid Frame") +} + +#[test] +fn unknown_frame() { + let res = test_packet(Some((r##"{ "op": 2 }"##, r#"{ "unknown": "header" }"#))); + if res.is_none() { + return; + } + panic!("Expected None") +} From 5ddbf3322b9cb480b77e4497d8b02fc3bc28d3c8 Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Sat, 14 Sep 2024 00:54:20 -0300 Subject: [PATCH 6/7] Adding tests for client + deleting XrpcUri --- atrium-streams-client/Cargo.toml | 3 +- atrium-streams-client/src/client.rs | 87 ----------------- atrium-streams-client/src/client/mod.rs | 96 +++++++++++++++++++ atrium-streams-client/src/client/tests.rs | 88 +++++++++++++++++ .../src/{client/mod.rs => client.rs} | 4 +- atrium-streams/src/client/xprc_uri.rs | 16 ---- examples/firehose/src/main.rs | 13 ++- 7 files changed, 193 insertions(+), 114 deletions(-) delete mode 100644 atrium-streams-client/src/client.rs create mode 100644 atrium-streams-client/src/client/mod.rs create mode 100644 atrium-streams-client/src/client/tests.rs rename atrium-streams/src/{client/mod.rs => client.rs} (94%) delete mode 100644 atrium-streams/src/client/xprc_uri.rs diff --git a/atrium-streams-client/Cargo.toml b/atrium-streams-client/Cargo.toml index af916392..2c399bda 100644 --- a/atrium-streams-client/Cargo.toml +++ b/atrium-streams-client/Cargo.toml @@ -28,4 +28,5 @@ thiserror.workspace = true [dev-dependencies] anyhow.workspace = true -serde_json.workspace = true \ No newline at end of file +serde_json.workspace = true +tokio = { version = "1.37", default-features = false, features = ["rt-multi-thread"] } \ No newline at end of file diff --git a/atrium-streams-client/src/client.rs b/atrium-streams-client/src/client.rs deleted file mode 100644 index 67ea9021..00000000 --- a/atrium-streams-client/src/client.rs +++ /dev/null @@ -1,87 +0,0 @@ -//! This file provides a client for the `ATProto` XRPC over WSS protocol. -//! It implements the [`EventStreamClient`] trait for the [`WssClient`] struct. - -use std::str::FromStr; - -use futures::Stream; -use tokio::net::TcpStream; - -use atrium_xrpc::{ - http::{Request, Uri}, - types::Header, -}; -use bon::Builder; -use serde::Serialize; -use tokio_tungstenite::{ - connect_async, - tungstenite::{self, handshake::client::generate_key}, - MaybeTlsStream, WebSocketStream, -}; - -use atrium_streams::client::{EventStreamClient, XrpcUri}; - -/// An enum of possible error kinds for this crate. -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Invalid uri")] - InvalidUri, - #[error("Parsing parameters failed: {0}")] - ParsingParameters(#[from] serde_html_form::ser::Error), - #[error("Connection error: {0}")] - Connection(#[from] tungstenite::Error), -} - -#[derive(Builder)] -pub struct WssClient<'a, P: Serialize> { - xrpc_uri: XrpcUri<'a>, - params: Option
<P>
, -} - -type StreamKind = WebSocketStream>; -impl EventStreamClient<::Item, Error> - for WssClient<'_, P> -{ - async fn connect(&self) -> Result::Item>, Error> { - let Self { xrpc_uri, params } = self; - let mut uri = xrpc_uri.to_uri(); - //// Query parameters - if let Some(p) = ¶ms { - uri.push('?'); - uri += &serde_html_form::to_string(p)?; - }; - //// - - //// Request - // Extracting the authority from the URI to set the Host header. - let uri = Uri::from_str(&uri).map_err(|_| Error::InvalidUri)?; - let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); - let host = authority - .find('@') - .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); - - // Building the request. - let mut request = Request::builder() - .uri(&uri) - .method("GET") - .header("Host", host) - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .header("Sec-WebSocket-Version", "13") - .header("Sec-WebSocket-Key", generate_key()); - - // Adding the ATProto headers. - if let Some(proxy) = self.atproto_proxy_header().await { - request = request.header(Header::AtprotoProxy, proxy); - } - if let Some(accept_labelers) = self.atproto_accept_labelers_header().await { - request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); - } - - // In our case, the only thing that could possibly fail is the URI. The headers are all `String`/`&str`. - let request = request.body(()).map_err(|_| Error::InvalidUri)?; - //// - - let (stream, _) = connect_async(request).await?; - Ok(stream) - } -} diff --git a/atrium-streams-client/src/client/mod.rs b/atrium-streams-client/src/client/mod.rs new file mode 100644 index 00000000..0d6a048d --- /dev/null +++ b/atrium-streams-client/src/client/mod.rs @@ -0,0 +1,96 @@ +//! This file provides a client for the `ATProto` XRPC over WSS protocol. +//! It implements the [`EventStreamClient`] trait for the [`WssClient`] struct. + +#[cfg(test)] +mod tests; + +use std::str::FromStr; + +use futures::Stream; +use tokio::net::TcpStream; + +use atrium_xrpc::{ + http::{Request, Uri}, + types::Header, +}; +use bon::Builder; +use serde::Serialize; +use tokio_tungstenite::{ + connect_async, + tungstenite::{self, handshake::client::generate_key}, + MaybeTlsStream, WebSocketStream, +}; + +use atrium_streams::client::EventStreamClient; + +/// An enum of possible error kinds for this crate. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Invalid uri")] + InvalidUri, + #[error("Parsing parameters failed: {0}")] + ParsingParameters(#[from] serde_html_form::ser::Error), + #[error("Connection error: {0}")] + Connection(#[from] tungstenite::Error), +} + +#[derive(Builder)] +pub struct WssClient { + params: Option
<P>
, +} + +type StreamKind = WebSocketStream>; +impl EventStreamClient<::Item, Error> + for WssClient
<P>
+{ + async fn connect(&self, mut uri: String) -> Result::Item>, Error> { + let Self { params } = self; + + // Query parameters + if let Some(p) = ¶ms { + uri.push('?'); + uri += &serde_html_form::to_string(p)?; + }; + + // Request + let (uri, host) = get_host(&uri)?; + let request = gen_request(self, &uri, &*host).await?; + + // Connection + let (stream, _) = connect_async(request).await?; + Ok(stream) + } +} + +/// Extract the URI and host from a string. +fn get_host(uri: &str) -> Result<(Uri, Box), Error> { + let uri = Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; + let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); + let host = authority + .find('@') + .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); + let host = Box::from(host); + Ok((uri, host)) +} + +/// Generate a request for the given URI and host. +/// It sets the necessary headers for a WebSocket connection, +/// plus the client's `AtprotoProxy` and `AtprotoAcceptLabelers` headers. +async fn gen_request(client: &WssClient
<P>
, uri: &Uri, host: &str) -> Result, Error> { + let mut request = Request::builder() + .uri(uri) + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()); + if let Some(proxy) = client.atproto_proxy_header().await { + request = request.header(Header::AtprotoProxy, proxy); + } + if let Some(accept_labelers) = client.atproto_accept_labelers_header().await { + request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); + } + let request = request.body(()).map_err(|_| Error::InvalidUri)?; + Ok(request) +} diff --git a/atrium-streams-client/src/client/tests.rs b/atrium-streams-client/src/client/tests.rs new file mode 100644 index 00000000..57d27cb0 --- /dev/null +++ b/atrium-streams-client/src/client/tests.rs @@ -0,0 +1,88 @@ +use std::net::{Ipv4Addr, SocketAddr}; + +use atrium_streams::{atrium_api::com::atproto::sync::subscribe_repos, client::EventStreamClient}; +use atrium_xrpc::http::{header::SEC_WEBSOCKET_KEY, HeaderMap, HeaderValue}; +use futures::{SinkExt, StreamExt}; +use tokio::{net::{TcpListener, TcpStream}, runtime::Runtime}; +use tokio_tungstenite::{tungstenite::{handshake::server::{ErrorResponse, Request, Response}, Message}, WebSocketStream}; + +use crate::WssClient; + +use super::{gen_request, get_host}; + +#[test] +fn client() { + let fut = async { + let ipv4 = Ipv4Addr::LOCALHOST.to_string(); + let xrpc_uri = format!("ws://{ipv4}:3000/xrpc/{}", subscribe_repos::NSID); + let (client, mut client_headers) = wss_client(&xrpc_uri).await; + + let server_handle = tokio::spawn(mock_wss_server()); + let mut client_stream = client.connect(xrpc_uri).await.unwrap(); + let (server_stream, mut server_headers, route) = server_handle.await.unwrap(); + + assert_eq!(route, format!("/xrpc/{}", subscribe_repos::NSID)); + + client_headers.remove(SEC_WEBSOCKET_KEY); + server_headers.remove(SEC_WEBSOCKET_KEY); + assert_eq!(client_headers, server_headers); + + let (mut inbound, _) = server_stream.split(); + inbound.send(Message::text("test_message")).await.unwrap(); + let msg = client_stream.next().await.unwrap().unwrap(); + assert_eq!(msg, Message::text("test_message")); + }; + Runtime::new().unwrap().block_on(fut); +} + +async fn wss_client(uri: &str) -> (WssClient, HeaderMap) { + let params = subscribe_repos::ParametersData { + cursor: None, + }; + + let client = WssClient::builder() + .params(params) + .build(); + + let (uri, host) = get_host(uri).unwrap(); + let req = gen_request(&client, &uri, &host).await.unwrap(); + let headers = req.headers(); + + (client, headers.clone()) +} + +async fn mock_wss_server() -> (WebSocketStream, HeaderMap, String) { + let sock_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 3000)); + + let listener = TcpListener::bind(sock_addr) + .await + .expect("Failed to bind to port!"); + + let headers: HeaderMap; + let route: String; + let (stream, _) = listener.accept().await.unwrap(); + let (headers_, route_, stream) = extract_headers(stream).await; + headers = headers_; + route = route_; + + (stream, headers, route) +} + +async fn extract_headers(raw_stream: TcpStream) -> (HeaderMap, String, WebSocketStream) { + let mut headers: Option> = None; + let mut route: Option = None; + + let copy_headers_callback = |request: &Request, response: Response| -> Result { + headers = Some(request.headers().clone()); + route = Some(request.uri().path().to_owned()); + Ok(response) + }; + + let stream = 
tokio_tungstenite::accept_hdr_async( + raw_stream, + copy_headers_callback, + ).await + .expect("Error during the websocket handshake occurred"); + + (headers.unwrap(), route.unwrap(), stream) +} \ No newline at end of file diff --git a/atrium-streams/src/client/mod.rs b/atrium-streams/src/client.rs similarity index 94% rename from atrium-streams/src/client/mod.rs rename to atrium-streams/src/client.rs index 3248a3d7..1f656749 100644 --- a/atrium-streams/src/client/mod.rs +++ b/atrium-streams/src/client.rs @@ -1,9 +1,6 @@ -mod xprc_uri; - use std::future::Future; use futures::Stream; -pub use xprc_uri::XrpcUri; /// An abstract WSS client. pub trait EventStreamClient { @@ -13,6 +10,7 @@ pub trait EventStreamClient { /// [`Result`] fn connect( &self, + uri: String ) -> impl Future, ConnectionError>> + Send; /// Get the `atproto-proxy` header. diff --git a/atrium-streams/src/client/xprc_uri.rs b/atrium-streams/src/client/xprc_uri.rs deleted file mode 100644 index 89d5a89f..00000000 --- a/atrium-streams/src/client/xprc_uri.rs +++ /dev/null @@ -1,16 +0,0 @@ -/// The URI for the XRPC `WebSocket` connection. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct XrpcUri<'a> { - base_uri: &'a str, - nsid: &'a str, -} -impl<'a> XrpcUri<'a> { - pub const fn new(base_uri: &'a str, nsid: &'a str) -> Self { - Self { base_uri, nsid } - } - - pub fn to_uri(&self) -> String { - let XrpcUri { base_uri, nsid } = self; - format!("wss://{base_uri}/xrpc/{nsid}") - } -} diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs index 8334f01c..ebd9e8b7 100644 --- a/examples/firehose/src/main.rs +++ b/examples/firehose/src/main.rs @@ -2,7 +2,7 @@ use anyhow::bail; use atrium_streams_client::{ atrium_streams::{ atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData}, - client::{XrpcUri, EventStreamClient}, + client::EventStreamClient, subscriptions::{ handlers::repositories::ProcessedData, ProcessedPayload, SubscriptionError, }, @@ -20,19 +20,19 @@ use tokio_tungstenite::tungstenite; /// This example demonstrates how to connect to the ATProto Firehose. #[tokio::main] async fn main() { - // Define the XrpcUri for the subscription. - let xrpc_uri = XrpcUri::new("bsky.network", subscribe_repos::NSID); + // Define the Uri for the subscription. + let uri = format!("wss://bsky.network/xrpc/{}", subscribe_repos::NSID); // Caching the last cursor is important. // The API has a backfilling mechanism that allows you to resume from where you stopped. let mut last_cursor = None; - drop(connect(&mut last_cursor, &xrpc_uri).await); + drop(connect(&mut last_cursor, uri).await); } /// Connects to `ATProto` to receive real-time data. async fn connect( last_cursor: &mut Option, - xrpc_uri: &XrpcUri<'_>, + uri: String, ) -> Result<(), anyhow::Error> { // Define the query parameters. In this case, just the cursor. let params = subscribe_repos::ParametersData { @@ -41,12 +41,11 @@ async fn connect( // Build a new XRPC WSS Client. let client = WssClient::builder() - .xrpc_uri(xrpc_uri.clone()) .params(params) .build(); // And then we connect to the API. 
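With `XrpcUri` gone, the caller now owns the endpoint string and `WssClient` only appends the serialized query parameters. A small sketch of what the client assembles internally; the cursor value is arbitrary, and the re-exported import path follows the example above.

```rust
use atrium_streams_client::atrium_streams::atrium_api::com::atproto::sync::subscribe_repos;

fn main() -> Result<(), serde_html_form::ser::Error> {
    let uri = format!("wss://bsky.network/xrpc/{}", subscribe_repos::NSID);
    let params = subscribe_repos::ParametersData { cursor: Some(42) };
    // `WssClient::connect` does the same: serialize the params with
    // serde_html_form and append them as a query string before opening the
    // WebSocket.
    let query = serde_html_form::to_string(&params)?;
    // Expected shape: wss://bsky.network/xrpc/com.atproto.sync.subscribeRepos?cursor=42
    println!("{uri}?{query}");
    Ok(())
}
```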
- let connection = match client.connect().await { + let connection = match client.connect(uri).await { Ok(connection) => connection, Err(Error::Connection(tungstenite::Error::Http(response))) => { // According to the API documentation, the following status codes are expected and should be treated accordingly: From 02ef85fa81b74e1948c0f9c9c241f0f405fcc13b Mon Sep 17 00:00:00 2001 From: Elaina <48662592+oestradiol@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:06:41 -0300 Subject: [PATCH 7/7] Formatting --- atrium-streams-client/src/client/mod.rs | 15 ++-- atrium-streams-client/src/client/tests.rs | 56 +++++++------ .../subscriptions/repositories/firehose.rs | 82 +++---------------- .../src/subscriptions/repositories/tests.rs | 76 ++++++----------- atrium-streams/src/client.rs | 2 +- .../src/subscriptions/frames/tests.rs | 9 +- atrium-streams/src/subscriptions/mod.rs | 5 +- 7 files changed, 80 insertions(+), 165 deletions(-) diff --git a/atrium-streams-client/src/client/mod.rs b/atrium-streams-client/src/client/mod.rs index 0d6a048d..29e694e7 100644 --- a/atrium-streams-client/src/client/mod.rs +++ b/atrium-streams-client/src/client/mod.rs @@ -43,7 +43,10 @@ type StreamKind = WebSocketStream>; impl EventStreamClient<::Item, Error> for WssClient
<P>
{ - async fn connect(&self, mut uri: String) -> Result::Item>, Error> { + async fn connect( + &self, + mut uri: String, + ) -> Result::Item>, Error> { let Self { params } = self; // Query parameters @@ -66,9 +69,7 @@ impl EventStreamClient<::Item, fn get_host(uri: &str) -> Result<(Uri, Box), Error> { let uri = Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); - let host = authority - .find('@') - .map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); + let host = authority.find('@').map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); let host = Box::from(host); Ok((uri, host)) } @@ -76,7 +77,11 @@ fn get_host(uri: &str) -> Result<(Uri, Box), Error> { /// Generate a request for the given URI and host. /// It sets the necessary headers for a WebSocket connection, /// plus the client's `AtprotoProxy` and `AtprotoAcceptLabelers` headers. -async fn gen_request(client: &WssClient
<P>, uri: &Uri, host: &str) -> Result<Request<()>, Error> { +async fn gen_request( + client: &WssClient<P>
, + uri: &Uri, + host: &str, +) -> Result, Error> { let mut request = Request::builder() .uri(uri) .method("GET") diff --git a/atrium-streams-client/src/client/tests.rs b/atrium-streams-client/src/client/tests.rs index 57d27cb0..d7265ab3 100644 --- a/atrium-streams-client/src/client/tests.rs +++ b/atrium-streams-client/src/client/tests.rs @@ -3,8 +3,17 @@ use std::net::{Ipv4Addr, SocketAddr}; use atrium_streams::{atrium_api::com::atproto::sync::subscribe_repos, client::EventStreamClient}; use atrium_xrpc::http::{header::SEC_WEBSOCKET_KEY, HeaderMap, HeaderValue}; use futures::{SinkExt, StreamExt}; -use tokio::{net::{TcpListener, TcpStream}, runtime::Runtime}; -use tokio_tungstenite::{tungstenite::{handshake::server::{ErrorResponse, Request, Response}, Message}, WebSocketStream}; +use tokio::{ + net::{TcpListener, TcpStream}, + runtime::Runtime, +}; +use tokio_tungstenite::{ + tungstenite::{ + handshake::server::{ErrorResponse, Request, Response}, + Message, + }, + WebSocketStream, +}; use crate::WssClient; @@ -35,15 +44,13 @@ fn client() { Runtime::new().unwrap().block_on(fut); } -async fn wss_client(uri: &str) -> (WssClient, HeaderMap) { - let params = subscribe_repos::ParametersData { - cursor: None, - }; +async fn wss_client( + uri: &str, +) -> (WssClient, HeaderMap) { + let params = subscribe_repos::ParametersData { cursor: None }; + + let client = WssClient::builder().params(params).build(); - let client = WssClient::builder() - .params(params) - .build(); - let (uri, host) = get_host(uri).unwrap(); let req = gen_request(&client, &uri, &host).await.unwrap(); let headers = req.headers(); @@ -54,11 +61,9 @@ async fn wss_client(uri: &str) -> (WssClient, H async fn mock_wss_server() -> (WebSocketStream, HeaderMap, String) { let sock_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 3000)); - let listener = TcpListener::bind(sock_addr) - .await - .expect("Failed to bind to port!"); + let listener = TcpListener::bind(sock_addr).await.expect("Failed to bind to port!"); - let headers: HeaderMap; + let headers: HeaderMap; let route: String; let (stream, _) = listener.accept().await.unwrap(); let (headers_, route_, stream) = extract_headers(stream).await; @@ -68,21 +73,22 @@ async fn mock_wss_server() -> (WebSocketStream, HeaderMap, String) { (stream, headers, route) } -async fn extract_headers(raw_stream: TcpStream) -> (HeaderMap, String, WebSocketStream) { +async fn extract_headers( + raw_stream: TcpStream, +) -> (HeaderMap, String, WebSocketStream) { let mut headers: Option> = None; let mut route: Option = None; - let copy_headers_callback = |request: &Request, response: Response| -> Result { - headers = Some(request.headers().clone()); - route = Some(request.uri().path().to_owned()); - Ok(response) - }; + let copy_headers_callback = + |request: &Request, response: Response| -> Result { + headers = Some(request.headers().clone()); + route = Some(request.uri().path().to_owned()); + Ok(response) + }; - let stream = tokio_tungstenite::accept_hdr_async( - raw_stream, - copy_headers_callback, - ).await + let stream = tokio_tungstenite::accept_hdr_async(raw_stream, copy_headers_callback) + .await .expect("Error during the websocket handshake occurred"); (headers.unwrap(), route.unwrap(), stream) -} \ No newline at end of file +} diff --git a/atrium-streams-client/src/subscriptions/repositories/firehose.rs b/atrium-streams-client/src/subscriptions/repositories/firehose.rs index e061f617..e07bde2f 100644 --- a/atrium-streams-client/src/subscriptions/repositories/firehose.rs +++ 
b/atrium-streams-client/src/subscriptions/repositories/firehose.rs @@ -136,19 +136,8 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Commit, ) -> Result>, Self::HandlingError> { - let CommitData { - blobs, - blocks, - commit, - ops, - repo, - rev, - seq, - since, - time, - too_big, - .. - } = payload.data; + let CommitData { blobs, blocks, commit, ops, repo, rev, seq, since, time, too_big, .. } = + payload.data; // If it is too big, the blocks and ops are not sent, so we skip the processing. let ops_opt = if too_big { @@ -172,15 +161,7 @@ impl Handler for Firehose { Ok(Some(ProcessedPayload { seq: Some(seq), - data: Self::ProcessedCommitData { - ops: ops_opt, - blobs, - commit, - repo, - rev, - since, - time, - }, + data: Self::ProcessedCommitData { ops: ops_opt, blobs, commit, repo, rev, since, time }, })) } @@ -189,12 +170,7 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Identity, ) -> Result>, Self::HandlingError> { - let IdentityData { - did, - handle, - seq, - time, - } = payload.data; + let IdentityData { did, handle, seq, time } = payload.data; Ok(Some(ProcessedPayload { seq: Some(seq), data: Self::ProcessedIdentityData { did, handle, time }, @@ -206,21 +182,10 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Account, ) -> Result>, Self::HandlingError> { - let AccountData { - did, - seq, - time, - active, - status, - } = payload.data; + let AccountData { did, seq, time, active, status } = payload.data; Ok(Some(ProcessedPayload { seq: Some(seq), - data: Self::ProcessedAccountData { - did, - active, - status, - time, - }, + data: Self::ProcessedAccountData { did, active, status, time }, })) } @@ -229,12 +194,7 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Handle, ) -> Result>, Self::HandlingError> { - let HandleData { - did, - handle, - seq, - time, - } = payload.data; + let HandleData { did, handle, seq, time } = payload.data; Ok(Some(ProcessedPayload { seq: Some(seq), data: Self::ProcessedHandleData { did, handle, time }, @@ -246,19 +206,10 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Migrate, ) -> Result>, Self::HandlingError> { - let MigrateData { - did, - migrate_to, - seq, - time, - } = payload.data; + let MigrateData { did, migrate_to, seq, time } = payload.data; Ok(Some(ProcessedPayload { seq: Some(seq), - data: Self::ProcessedMigrateData { - did, - migrate_to, - time, - }, + data: Self::ProcessedMigrateData { did, migrate_to, time }, })) } @@ -279,10 +230,7 @@ impl Handler for Firehose { &self, payload: subscribe_repos::Info, ) -> Result>, Self::HandlingError> { - Ok(Some(ProcessedPayload { - seq: None, - data: payload.data, - })) + Ok(Some(ProcessedPayload { seq: None, data: payload.data })) } } @@ -315,15 +263,9 @@ fn process_op( // Finds in the map the `Record` with the operation's CID and deserializes it. // If the item is not found, returns `None`. 
let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) { - Some(item) => Some(serde_ipld_dagcbor::from_reader::( - Cursor::new(item), - )?), + Some(item) => Some(serde_ipld_dagcbor::from_reader::(Cursor::new(item))?), None => None, }; - Ok(Operation { - action, - path, - record, - }) + Ok(Operation { action, path, record }) } diff --git a/atrium-streams-client/src/subscriptions/repositories/tests.rs b/atrium-streams-client/src/subscriptions/repositories/tests.rs index c68059ed..d94dd78a 100644 --- a/atrium-streams-client/src/subscriptions/repositories/tests.rs +++ b/atrium-streams-client/src/subscriptions/repositories/tests.rs @@ -69,11 +69,7 @@ fn test_packet( packet: Option<(&str, &str)>, ) -> Option, HandledData), SubscriptionError>> { - let connection = mock_connection(if let Some(packet) = packet { - vec![packet] - } else { - vec![] - }); + let connection = mock_connection(if let Some(packet) = packet { vec![packet] } else { vec![] }); let subscription = gen_default_subscription(connection); @@ -99,10 +95,7 @@ fn gen_default_subscription( .enable_tombstone(true) .enable_info(true) .build(); - let subscription = Repositories::builder() - .connection(connection) - .handler(firehose) - .build(); + let subscription = Repositories::builder().connection(connection).handler(firehose).build(); subscription } @@ -117,9 +110,7 @@ fn disconnect() { #[test] fn invalid_packet() { if let SubscriptionError::Abort(msg) = - test_packet(Some(("{ not-a-header }", "{ not-a-payload }"))) - .unwrap() - .unwrap_err() + test_packet(Some(("{ not-a-header }", "{ not-a-payload }"))).unwrap().unwrap_err() { assert_eq!(msg, "Received invalid packet. Error: Utf8"); return; @@ -135,10 +126,7 @@ fn commit() { data: Some(CommitData { blobs: vec![], blocks: vec![], - commit: CidLink(Cid::new_v1( - 0x70, - Multihash::<64>::wrap(0x12, &[0; 64]).unwrap(), - )), + commit: CidLink(Cid::new_v1(0x70, Multihash::<64>::wrap(0x12, &[0; 64]).unwrap())), ops: vec![], prev: None, rebase: false, @@ -152,9 +140,8 @@ fn commit() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ -187,9 +174,8 @@ fn identity() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#identity" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#identity" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ -218,9 +204,8 @@ fn account() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#account" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#account" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ -249,9 +234,8 @@ fn handle() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#handle" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#handle" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ 
-279,9 +263,8 @@ fn migrate() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#migrate" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#migrate" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ -308,9 +291,8 @@ fn tombstone() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#tombstone" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#tombstone" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, Some(99)); assert_eq!( format!("{:?}", data), @@ -333,9 +315,8 @@ fn info() { extra_data: Ipld::Null, }; let body = serde_json::to_string(&body).unwrap(); - let (seq, data) = test_packet(Some((r##"{ "op": 1, "t": "#info" }"##, &body))) - .unwrap() - .unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#info" }"##, &body))).unwrap().unwrap(); assert_eq!(seq, None); assert_eq!( format!("{:?}", data), @@ -349,11 +330,8 @@ fn info() { #[test] fn ignored_frame() { - if test_packet(Some(( - r##"{ "op": 1, "t": "#non-existent" }"##, - r#"{ "foo": "bar" }"#, - ))) - .is_none() + if test_packet(Some((r##"{ "op": 1, "t": "#non-existent" }"##, r#"{ "foo": "bar" }"#))) + .is_none() { return; } @@ -366,10 +344,7 @@ fn invalid_body() { data: Some(CommitData { blobs: vec![], blocks: vec![1], // Invalid CAR file - commit: CidLink(Cid::new_v1( - 0x70, - Multihash::<64>::wrap(0x12, &[0; 64]).unwrap(), - )), + commit: CidLink(Cid::new_v1(0x70, Multihash::<64>::wrap(0x12, &[0; 64]).unwrap())), ops: vec![], prev: None, rebase: false, @@ -384,9 +359,7 @@ fn invalid_body() { }; let body = serde_json::to_string(&body).unwrap(); if let SubscriptionError::Abort(msg) = - test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))) - .unwrap() - .unwrap_err() + test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))).unwrap().unwrap_err() { assert_eq!( msg, @@ -447,10 +420,7 @@ fn unknown_error() { fn empty_payload() { let res = test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, r#""#))); if let SubscriptionError::Abort(msg) = res.unwrap().unwrap_err() { - assert_eq!( - "Received empty payload for header: {\"op\": 1, \"t\": \"#commit\"}", - &msg - ); + assert_eq!("Received empty payload for header: {\"op\": 1, \"t\": \"#commit\"}", &msg); return; } panic!("Expected Empty Payload") diff --git a/atrium-streams/src/client.rs b/atrium-streams/src/client.rs index 1f656749..5226c194 100644 --- a/atrium-streams/src/client.rs +++ b/atrium-streams/src/client.rs @@ -10,7 +10,7 @@ pub trait EventStreamClient { /// [`Result`] fn connect( &self, - uri: String + uri: String, ) -> impl Future, ConnectionError>> + Send; /// Get the `atproto-proxy` header. 
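With `XrpcUri` removed, callers of the trait above now build the websocket URI themselves and hand it to `connect` as a plain `String`. A minimal call-site sketch, mirroring examples/firehose/src/main.rs from this series (the `bsky.network` host is taken from that example, and the `?` error propagation is assumed here only for brevity):

    // Build the endpoint URI by hand now that XrpcUri is gone.
    let uri = format!("wss://bsky.network/xrpc/{}", subscribe_repos::NSID);
    // Query parameters are still supplied through the client builder, not the URI.
    let client = WssClient::builder()
        .params(subscribe_repos::ParametersData { cursor: None })
        .build();
    // The URI is passed at call time, so one client could in principle be pointed
    // at different hosts across calls.
    let mut stream = client.connect(uri).await?;
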
diff --git a/atrium-streams/src/subscriptions/frames/tests.rs b/atrium-streams/src/subscriptions/frames/tests.rs index bfab5a9d..10804549 100644 --- a/atrium-streams/src/subscriptions/frames/tests.rs +++ b/atrium-streams/src/subscriptions/frames/tests.rs @@ -7,10 +7,7 @@ fn serialized_data(s: &str) -> Vec { b'a'..=b'f' => b - b'a' + 10, _ => unreachable!(), }; - s.as_bytes() - .chunks(2) - .map(|b| (b2u(b[0]) << 4) + b2u(b[1])) - .collect() + s.as_bytes().chunks(2).map(|b| (b2u(b[0]) << 4) + b2u(b[1])).collect() } #[test] @@ -21,9 +18,7 @@ fn deserialize_message_frame_header() { let result = FrameHeader::try_from(ipld); assert_eq!( result.expect("failed to deserialize"), - FrameHeader::Message { - t: String::from("#commit") - } + FrameHeader::Message { t: String::from("#commit") } ); } diff --git a/atrium-streams/src/subscriptions/mod.rs b/atrium-streams/src/subscriptions/mod.rs index 67beca0a..46f113b7 100644 --- a/atrium-streams/src/subscriptions/mod.rs +++ b/atrium-streams/src/subscriptions/mod.rs @@ -53,10 +53,7 @@ pub struct ProcessedPayload { /// Helper function to convert between payload kinds. impl ProcessedPayload { pub fn map NewKind>(self, f: F) -> ProcessedPayload { - ProcessedPayload { - seq: self.seq, - data: f(self.data), - } + ProcessedPayload { seq: self.seq, data: f(self.data) } } }
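
The `map` helper in the final hunk above transforms only `data` and carries `seq` through untouched. A small usage sketch, assuming `ProcessedPayload` is generic over its data kind as the hunk suggests (the `Option` wrapper is arbitrary, chosen only because it works for any kind):

    // Hypothetical helper: lift a payload's data into an Option without touching seq.
    fn wrap<K>(p: ProcessedPayload<K>) -> ProcessedPayload<Option<K>> {
        // `Some` is a plain fn(K) -> Option<K>, so it satisfies map's closure bound.
        p.map(Some)
    }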