diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d00fae..e69de29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,303 +0,0 @@ -# Unreleased - -# 0.13.0 - -- Add option to always save session. #216 - -# 0.12.3 - -- Ensure `continuously_delete_expired` waits for initial run. #208 - -# 0.12.2 - -- Ensure `set_expiry` mutates `Max-Age`. #191 - -This addresses a bug where using `set_expiry` on a session with no initial expiry time would not add the Max-age attribute to the cookie leading to an inconsitency between the cookie and the database. - -# 0.12.1 - -**Important Security Update** - -- Ensure ID cycling invokes `create`. #188 - -Because cycling the session ID involves creating a new ID, this must follow the same semantics as normal session creation. Therefore prior to this fix session ID collision could occur through this vector. - -# 0.12.0 - -**Important Security Update** - -- Id collision mitigation. #181 - -This release introduces a new method, `create`, to the `SessionStore` trait to distinguish between creating a new session and updating an existing one. **This distinction is crucial for mitigating the potential for session ID collisions.** - -Although the probability of session ID collisions is statistically low, given that IDs are composed of securely-random `i128` values, such collisions pose a significant security risk. A store that does not differentiate between session creation and updates could inadvertently allow an existing session to be accessed, leading to potential session takeovers. - -Session store authors are strongly encouraged to update and implement `create` such that potential ID collisions are handled, either by generating a new ID or returning an error. - -As a transitional measure, we have provided a default implementation of `create` that wraps the existing `save` method. However, this default is not immune to the original issue. Therefore, it is imperative that stores override the `create` method with an implementation that adheres to the required uniqueness semantics, thereby effectively mitigating the risk of session ID collisions. - -# 0.11.1 - -- Ensure `session.set_expiry` updates record. #175 -- Provide `signed` and `private` features, enabling signing and encryption respectively. #157 - -# 0.11.0 - -- Uses slices when encoding and decoding `Id`. #159 - -**Breaking Changes** - -- Removes `IdError` type in favor of using `base64::DecodeSliceError`. #159 -- Provides the same changes as 0.10.4, without breaking SemVer. -- Updates `base64` to `0.22.0`. - -# ~0.10.4~ **Yanked:** SemVer breaking - -- Revert introduction of lifetime parameter; use static lifetime directly - -This ensures that the changes introduced in `0.10.3` do not break SemVer. - -Please note that `0.10.3` has been yanked in accordance with cargo guidelines. - -# ~0.10.3~ **Yanked:** SemVer breaking - -- Improve session config allocation footprint #158 - -# 0.10.2 - -- Ensure "Path" and "Domain" are set on removal cookie #154 - -# 0.10.1 - -- Ensure `Expires: Session` #149 - -# 0.10.0 - -**Breaking Changes** - -- Improve session ID #141 -- Relocate previously bundled stores #145 -- Move service out of core #146 - -Session IDs are now represetned as base64-encoded `i128`s, boast 128 bits of entropy, and are shorter, saving network bandwidth and improving the secure nature of sessions. - -We no longer bundle session stores via feature flags and as such applications must be updated to require the stores directly. For example, applications that use the `tower-sessions-sqlx-store` should update their `Cargo.toml` like so: - -```toml -tower-sessions = "0.10.0" -tower-sessions-sqlx-store = { version = "0.10.0", features = ["sqlite"] } -``` - -Assuming a SQLite store, as an example. - -Furthermore, imports will also need to be updated accordingly. For example: - -```rust -use std::net::SocketAddr; - -use axum::{response::IntoResponse, routing::get, Router}; -use serde::{Deserialize, Serialize}; -use time::Duration; -use tower_sessions::{session_store::ExpiredDeletion, Expiry, Session, SessionManagerLayer}; -use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore}; - -const COUNTER_KEY: &str = "counter"; - -#[derive(Serialize, Deserialize, Default)] -struct Counter(usize); - -#[tokio::main] -async fn main() -> Result<(), Box> { - let pool = SqlitePool::connect("sqlite::memory:").await?; - let session_store = SqliteStore::new(pool); - session_store.migrate().await?; - - let deletion_task = tokio::task::spawn( - session_store - .clone() - .continuously_delete_expired(tokio::time::Duration::from_secs(60)), - ); - - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::seconds(10))); - - let app = Router::new().route("/", get(handler)).layer(session_layer); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = tokio::net::TcpListener::bind(&addr).await?; - axum::serve(listener, app.into_make_service()).await?; - - deletion_task.await??; - - Ok(()) -} - -async fn handler(session: Session) -> impl IntoResponse { - let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); - session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); - format!("Current count: {}", counter.0) -} -``` - -Finally, the service itself has been moved out of the core crate, which makes this crate smaller as well as establishes better boundaries between code. - -Thank you for bearing with us: we are approaching longer term stability and aim to minimize churn going forward as we begin to move toward a 1.0 release. - -# 0.9.1 - -- Ensure `clear` works before record loading. #134 - -# 0.9.0 - -**Breakiung Changes** - -- Make service infallible. #132 - -This updates the service such that it always returns a response directly. In practice this means that e.g. `axum` applications no longer need the `HandleErrorLayer` and instead can use the layer directly. Note that if you use other fallible `tower` middleware, you will still need to use `HandleErrorLayer`. - -As such we've also remove the `MissingCookies` and `MissingId` variants from the session error enum. - -# 0.8.2 - -- Derive `PartialEq` for `Record`. #125 - -# 0.8.1 - -- Allow constructing `RedisStore` from `RedisPool`. #122 - -# 0.8.0 - -**Breaking Changes** - -- Lazy sessions. #112 - -Among other things, session methods are now entirely async, meaning applications must be updated to await these methods in order to migrate. - -Separately, `SessionStore` has been updated to use a `Record` intermediary. As such, `SessionStore` implementations must be updated accordingly. - -Session stores now use a concrete error type that must be used in implementations of `SessionStore`. - -The `secure` cookie attribute now defaults to `true`. - -# 0.7.0 - -**Breaking Changes** - -- Bump `axum-core` to 0.4.0, `http` to 1.0, `tower-cookies` to 0.10.0. #107 - -This brings `tower-cookies` up-to-date which includes an update to the `cookies` crate. - -# 0.6.0 - -**Breaking Changes** - -- Remove concurrent shared memory access support; this may also address some performance degradations. #91 -- Related to shared memory support, we also remove `replace_if_equal`, as it is no longer relevant. #91 - -**Other Changes** - -- Allow setting up table and schema name for Postgres. #93 - -# 0.5.1 - -- Only delete from session store if we have a session cookie. #90 - -# 0.5.0 - -**Breaking Changes** - -- Use a default session name of "id" to avoid fingerprinting, as per https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting. - -Note that applications using the old default, "tower.sid", may continue to do so without disruption by specifying [`with_name("tower.sid")`](https://docs.rs/tower-sessions/latest/tower_sessions/service/struct.SessionManagerLayer.html#method.with_name). - -# 0.4.3 - -## **Important Security Fix** - -If your application uses `MokaStore` or `MemoryStore`, please update immediately to ensure proper server-side handling of expired sessions. - -**Other Changes** - -- Make `HttpOnly` configurable. #81 - -# 0.4.2 - -- Provide tracing instrumentation. -- Ensure non-negative max-age. #79 - -# 0.4.1 - -- Fix lifecycle state persisting in stores when it should not. #71 - -# 0.4.0 - -**Breaking Changes** - -- Sessions are serialized and deserialized from stores directly and `SessionRecord` is removed. -- Expiration time has been replaced with an expiry type. -- Drop session-prefix from session types. -- The session `modified` methid is renamed to `is_modified`. -- Session active semantic is now defined by stores and the `active` method removed. -- Service now contains session configuration and `CookieConfig` is removed. -- Deletion task is now provided via the `deletion-task` feature flag. - -# 0.3.3 - -- Ensure loaded sessions are removed whenever they can be; do not couple removal with session saving. - -# 0.3.2 - -- Implement reference-counted garbage collection for loaded sessions. #52 -- Make `SessionId`'s UUID public. #53 - -# 0.3.1 - -- Use `DashMap` entry API to address data race introduced by dashmap. #41 - -# 0.3.0 - -**Breaking Changes** - -- `tokio` feature flag is now `tokio-rt`. -- Session IDs are returned as references now. - -**Other Changes** - -- Update `fred` to 7.0.0. -- Track loaded sessions to enable concurrent access. #37 - -# 0.2.4 - -- Fix session saving and loading potential data race. #36 - -# 0.2.3 - -- Fix setting of modified in `replace_if_equal`. - -# 0.2.2 - -- Lift `Debug` constraint on `CachingSessionStore`. -- Run caching store save and load ops concurrently. #25 - -# 0.2.1 - -- Fix clearing session's data is not persisted. #22 - -# 0.2.0 - -**Breaking Changes** - -- Renamed store error variants for consistency (SqlxStoreError, RedisStoreError). #18 -- Moved MySQL `expiration_time` column to `timestamp(6), for microsecond resolution. #14 -- Replaced `Session.with_max_age` with `set_expiration_time` and `set_expiration_time_from_max_age`, allowing applications to control session durations dynamically. #7 - -**Other Changes** - -- Provide layered caching via `CachingSessionStore` #8 -- Provide a Moka store #6 (Thank you @and-reas-se!) -- Provide a MongoDB store #5 (Thank you @JustMangoT!) - -# 0.1.0 - -- Initial release :tada: diff --git a/Cargo.toml b/Cargo.toml index 0c73a64..fbd5249 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "memory-store", "tower-sessions-core"] +members = [".", "tower-sesh-core", "memory-store"] resolver = "2" [workspace.package] @@ -7,16 +7,16 @@ version = "0.13.0" edition = "2021" authors = ["Max Countryman "] license = "MIT" -homepage = "https://github.com/maxcountryman/tower-sessions" +homepage = "https://github.com/maxcountryman/tower-sesh" keywords = ["axum", "session", "sessions", "cookie", "tower"] categories = ["asynchronous", "network-programming", "web-programming"] -repository = "https://github.com/maxcountryman/tower-sessions" -documentation = "https://docs.rs/tower-sessions" +repository = "https://github.com/maxcountryman/tower-sesh" +documentation = "https://docs.rs/tower-sesh" readme = "README.md" [package] -name = "tower-sessions" -description = "🥠 Sessions as a `tower` and `axum` middleware." +name = "tower-sesh" +description = "Cookie sessions as a `tower` and `axum` middleware." version.workspace = true edition.workspace = true authors.workspace = true @@ -30,75 +30,48 @@ readme.workspace = true [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] - -[features] -default = ["axum-core", "memory-store"] -axum-core = ["tower-sessions-core/axum-core"] -memory-store = ["tower-sessions-memory-store"] -signed = ["tower-cookies/signed"] -private = ["tower-cookies/private"] +rustdoc-args = ["--cfg", "docsrs", "--generate-link-to-definition"] +cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] [workspace.dependencies] -tower-sessions = { version = "=0.13.0", path = ".", default-features = false } - -tower-sessions-core = { version = "=0.13.0", path = "tower-sessions-core", default-features = false } -tower-sessions-memory-store = { version = "=0.13.0", path = "memory-store" } +tower-sesh = { version = "=0.13.0", path = ".", features = ["memory-store", "extractor"] } +tower-sesh-core = { version = "=0.13.0", path = "tower-sesh-core" } +tower-sesh-memory-store = { version = "=0.13.0", path = "memory-store" } -async-trait = "0.1.74" -parking_lot = { version = "0.12.1", features = ["serde"] } -rmp-serde = { version = "1.1.2" } -serde = "1.0.192" -thiserror = "1.0.50" time = "0.3.30" -tokio = { version = "1.32.0", default-features = false, features = ["sync"] } +tokio = { version = "1.32.0", default-features = false } + +[features] +memory-store = ["tower-sesh-memory-store"] +extractor = ["dep:axum-core", "dep:async-trait"] [dependencies] -async-trait = "0.1.73" +async-trait = { version = "0.1.74", optional = true } +axum-core = { version = "0.4", optional = true } +cookie = "0.18.1" http = "1.0" -tokio = { version = "1.32.0", features = ["sync"] } +pin-project-lite = "0.2.14" +time = { workspace = true, features = ["serde"] } tower-layer = "0.3.2" tower-service = "0.3.2" -tower-sessions-core = { workspace = true } -tower-sessions-memory-store = { workspace = true, optional = true } +tower-sesh-core = { workspace = true } +tower-sesh-memory-store = { workspace = true, optional = true } tracing = { version = "0.1.40", features = ["log"] } -tower-cookies = "0.10.0" -time = { version = "0.3.29", features = ["serde"] } [dev-dependencies] -async-trait = "0.1.74" anyhow = "1" axum = "0.7.1" axum-core = "0.4.0" -futures = { version = "0.3.28", default-features = false, features = [ - "async-await", -] } http = "1.0" http-body-util = "0.1" hyper = "1.0" -reqwest = { version = "0.12.3", default-features = false, features = [ - "rustls-tls", -] } -serde = "1.0.192" -time = "0.3.30" +time = { workspace = true } tokio = { version = "1.32.0", features = ["full"] } -tokio-test = "0.4.3" tower = { version = "0.5.0", features = ["util"] } -tower-cookies = "0.10.0" -tower-sessions-core = { workspace = true, features = ["deletion-task"] } +tower-sesh-core = { workspace = true } +tower-sesh-memory-store = { workspace = true } +tower-sesh = { workspace = true } [[example]] name = "counter" -required-features = ["axum-core", "memory-store"] - -[[example]] -name = "counter-extractor" -required-features = ["axum-core", "memory-store"] - -[[example]] -name = "strongly-typed" -required-features = ["axum-core"] - -[[example]] -name = "signed" -required-features = ["signed", "memory-store"] +required-features = ["memory-store"] diff --git a/LICENSE b/LICENSE index c8e38d4..1938c1c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 Max Countryman +Copyright (c) 2024 Charles Edward Gagnon Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 8ce58d1..00f07e9 100644 --- a/README.md +++ b/README.md @@ -1,147 +1,37 @@ -

- tower-sessions -

+# tower-sesh +An opinionated cookie session middleware for `tower` services. -

- 🥠 Sessions as a `tower` and `axum` middleware. -

+## Comparison with `tower-sessions` +`tower-sessions` tries to follow the design of `django`'s session middleware. As a consequence, +_every request going through the middleware_ does the following: +- allocate multiple `HashMap`s, +- Use dynamic dispatch for futures (using `Pin>`) +- make extensive use of the `Arc>` magic sauce. -
- - - - - - - - - - - - -
- -## 🎨 Overview - -This crate provides sessions, key-value pairs associated with a site -visitor, as a `tower` middleware. - -It offers: - -- **Pluggable Storage Backends:** Bring your own backend simply by - implementing the `SessionStore` trait, fully decoupling sessions from their - storage. -- **Minimal Overhead**: Sessions are only loaded from their backing stores - when they're actually used and only in e.g. the handler they're used in. - That means this middleware can be installed anywhere in your route - graph with minimal overhead. -- **An `axum` Extractor for `Session`:** Applications built with `axum` - can use `Session` as an extractor directly in their handlers. This makes - using sessions as easy as including `Session` in your handler. -- **Simple Key-Value Interface:** Sessions offer a key-value interface that - supports native Rust types. So long as these types are `Serialize` and can - be converted to JSON, it's straightforward to insert, get, and remove any - value. -- **Strongly-Typed Sessions:** Strong typing guarantees are easy to layer on - top of this foundational key-value interface. - -This crate's session implementation is inspired by the [Django sessions middleware](https://docs.djangoproject.com/en/4.2/topics/http/sessions) and it provides a transliteration of those semantics. - -### Session stores +We don't do that here. +## Session stores Session data persistence is managed by user-provided types that implement `SessionStore`. What this means is that applications can and should implement session stores to fit their specific needs. -That said, a number of session store implmentations already exist and may be +That said, a number of session store implementations already exist and may be useful starting points. -| Crate | Persistent | Description | -| ---------------------------------------------------------------------------------------------------------------- | ---------- | ----------------------------------------------------------- | -| [`tower-sessions-dynamodb-store`](https://github.com/necrobious/tower-sessions-dynamodb-store) | Yes | DynamoDB session store | -| [`tower-sessions-firestore-store`](https://github.com/AtTheTavern/tower-sessions-firestore-store) | Yes | Firestore session store | -| [`tower-sessions-libsql-store`](https://github.com/daybowbow-dev/tower-sessions-libsql-store) | Yes | libSQL session store | -| [`tower-sessions-mongodb-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/mongodb-store) | Yes | MongoDB session store | -| [`tower-sessions-moka-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/moka-store) | No | Moka session store | -| [`tower-sessions-redis-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/redis-store) | Yes | Redis via `fred` session store | -| [`tower-sessions-rorm-store`](https://github.com/rorm-orm/tower-sessions-rorm-store) | Yes | SQLite, Postgres and Mysql session store provided by `rorm` | -| [`tower-sessions-rusqlite-store`](https://github.com/patte/tower-sessions-rusqlite-store) | Yes | Rusqlite session store | -| [`tower-sessions-sled-store`](https://github.com/Zatzou/tower-sessions-sled-store) | Yes | Sled session store | -| [`tower-sessions-sqlx-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/sqlx-store) | Yes | SQLite, Postgres, and MySQL session stores | -| [`tower-sessions-surrealdb-store`](https://github.com/rynoV/tower-sessions-surrealdb-store) | Yes | SurrealDB session store | +| Crate | Persistent | Description | +| ---------------------------------------------------------------------------------| ---------- | ------------------------- | +| [`tower-sesh-redis-store`](https://github.com/carloskiki/tower-sesh-redis-store) | Yes | Redis using `redis` crate | Have a store to add? Please open a PR adding it. -### User session management - -To facilitate authentication and authorization, we've built [`axum-login`](https://github.com/maxcountryman/axum-login) on top of this crate. Please check it out if you're looking for a generalized auth solution. - -## 📦 Install - -To use the crate in your project, add the following to your `Cargo.toml` file: +## Usage +This crate is not published on crates.io. You need to add it as a git dependency. ```toml [dependencies] -tower-sessions = "0.13.0" +tower-sesh = { git = "https://github.com/carloskiki/tower-sesh.git" } ``` -## 🤸 Usage - -### `axum` Example - -```rust -use std::net::SocketAddr; - -use axum::{response::IntoResponse, routing::get, Router}; -use serde::{Deserialize, Serialize}; -use time::Duration; -use tower_sessions::{Expiry, MemoryStore, Session, SessionManagerLayer}; - -const COUNTER_KEY: &str = "counter"; - -#[derive(Default, Deserialize, Serialize)] -struct Counter(usize); - -async fn handler(session: Session) -> impl IntoResponse { - let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); - session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); - format!("Current count: {}", counter.0) -} - -#[tokio::main] -async fn main() { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::seconds(10))); - - let app = Router::new().route("/", get(handler)).layer(session_layer); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app.into_make_service()) - .await - .unwrap(); -} -``` - -You can find this [example][counter-example] as well as other example projects in the [example directory][examples]. - -> [!NOTE] -> See the [crate documentation][docs] for more usage information. - -## 🦺 Safety - -This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. - -## 🛟 Getting Help - -We've put together a number of [examples][examples] to help get you started. You're also welcome to [open a discussion](https://github.com/maxcountryman/tower-sessions/discussions/new?category=q-a) and ask additional questions you might have. - -## 👯 Contributing - -We appreciate all kinds of contributions, thank you! +## Contributing -[counter-example]: https://github.com/maxcountryman/tower-sessions/tree/main/examples/counter.rs -[examples]: https://github.com/maxcountryman/tower-sessions/tree/main/examples -[docs]: https://docs.rs/tower-sessions +All contributions are welcome! All are licensed under the MIT license. diff --git a/examples/counter-extractor.rs b/examples/counter-extractor.rs deleted file mode 100644 index e55e12a..0000000 --- a/examples/counter-extractor.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::net::SocketAddr; - -use async_trait::async_trait; -use axum::{extract::FromRequestParts, response::IntoResponse, routing::get, Router}; -use http::request::Parts; -use serde::{Deserialize, Serialize}; -use time::Duration; -use tower_sessions::{Expiry, MemoryStore, Session, SessionManagerLayer}; - -const COUNTER_KEY: &str = "counter"; - -#[derive(Default, Deserialize, Serialize)] -struct Counter(usize); - -#[async_trait] -impl FromRequestParts for Counter -where - S: Send + Sync, -{ - type Rejection = (http::StatusCode, &'static str); - - async fn from_request_parts(req: &mut Parts, state: &S) -> Result { - let session = Session::from_request_parts(req, state).await?; - let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); - session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); - Ok(counter) - } -} - -#[tokio::main] -async fn main() { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::seconds(10))); - - let app = Router::new().route("/", get(handler)).layer(session_layer); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app.into_make_service()) - .await - .unwrap(); -} - -async fn handler(counter: Counter) -> impl IntoResponse { - format!("Current count: {}", counter.0) -} diff --git a/examples/counter.rs b/examples/counter.rs index d39a9df..b027ab7 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -1,27 +1,49 @@ use std::net::SocketAddr; use axum::{response::IntoResponse, routing::get, Router}; -use serde::{Deserialize, Serialize}; use time::Duration; -use tower_sessions::{Expiry, MemoryStore, Session, SessionManagerLayer}; +use tower_sesh::{Expires, Expiry, MemoryStore, Session, SessionManagerLayer}; -const COUNTER_KEY: &str = "counter"; - -#[derive(Default, Deserialize, Serialize)] +#[derive(Clone, Copy, Debug)] struct Counter(usize); -async fn handler(session: Session) -> impl IntoResponse { - let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); - session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); - format!("Current count: {}", counter.0) +impl Expires for Counter { + fn expires(&self) -> Expiry { + Expiry::OnInactivity(Duration::seconds(10)) + } +} + +async fn handler(session: Session>) -> impl IntoResponse { + let value = if let Some(counter_state) = session.clone().load::().await.unwrap() { + // We loaded the session, let's update the counter. + match counter_state + .update(|counter| counter.0 += 1) + .await + .unwrap() + { + Some(new_state) => new_state.data().0, + None => { + // The session has expired while we were updating it, let's create a new one. + session.create(Counter(0)).await.unwrap(); + 0 + } + } + } else { + // No session found, let's create a new one. + session.create(Counter(0)).await.unwrap(); + 0 + }; + + format!("Current count: {}", value) } #[tokio::main] async fn main() { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::seconds(10))); + let session_store: MemoryStore = MemoryStore::default(); + let session_layer = SessionManagerLayer { + store: session_store, + config: Default::default(), + }; let app = Router::new().route("/", get(handler)).layer(session_layer); diff --git a/examples/signed.rs b/examples/signed.rs deleted file mode 100644 index cb5d687..0000000 --- a/examples/signed.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::net::SocketAddr; - -use axum::{response::IntoResponse, routing::get, Router}; -use serde::{Deserialize, Serialize}; -use time::Duration; -use tower_sessions::{cookie::Key, Expiry, MemoryStore, Session, SessionManagerLayer}; - -const COUNTER_KEY: &str = "counter"; - -#[derive(Default, Deserialize, Serialize)] -struct Counter(usize); - -async fn handler(session: Session) -> impl IntoResponse { - let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); - session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); - format!("Current count: {}", counter.0) -} - -#[tokio::main] -async fn main() { - let key = Key::generate(); // This is only used for demonstration purposes; provide a proper - // cryptographic key in a real application. - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::seconds(10))) - .with_signed(key); - - let app = Router::new().route("/", get(handler)).layer(session_layer); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/strongly-typed.rs b/examples/strongly-typed.rs deleted file mode 100644 index 46d02c0..0000000 --- a/examples/strongly-typed.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::{fmt, net::SocketAddr}; - -use async_trait::async_trait; -use axum::{extract::FromRequestParts, response::IntoResponse, routing::get, Router}; -use http::{request::Parts, StatusCode}; -use serde::{Deserialize, Serialize}; -use time::OffsetDateTime; -use tower_sessions::{MemoryStore, Session, SessionManagerLayer}; - -#[derive(Clone, Deserialize, Serialize)] -struct GuestData { - pageviews: usize, - first_seen: OffsetDateTime, - last_seen: OffsetDateTime, -} - -impl Default for GuestData { - fn default() -> Self { - Self { - pageviews: 0, - first_seen: OffsetDateTime::now_utc(), - last_seen: OffsetDateTime::now_utc(), - } - } -} - -struct Guest { - session: Session, - guest_data: GuestData, -} - -impl Guest { - const GUEST_DATA_KEY: &'static str = "guest.data"; - - fn first_seen(&self) -> OffsetDateTime { - self.guest_data.first_seen - } - - fn last_seen(&self) -> OffsetDateTime { - self.guest_data.last_seen - } - - fn pageviews(&self) -> usize { - self.guest_data.pageviews - } - - async fn mark_pageview(&mut self) { - self.guest_data.pageviews += 1; - Self::update_session(&self.session, &self.guest_data).await - } - - async fn update_session(session: &Session, guest_data: &GuestData) { - session - .insert(Self::GUEST_DATA_KEY, guest_data.clone()) - .await - .unwrap() - } -} - -impl fmt::Display for Guest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Guest") - .field("pageviews", &self.pageviews()) - .field("first_seen", &self.first_seen()) - .field("last_seen", &self.last_seen()) - .finish() - } -} - -#[async_trait] -impl FromRequestParts for Guest -where - S: Send + Sync, -{ - type Rejection = (StatusCode, &'static str); - - async fn from_request_parts(req: &mut Parts, state: &S) -> Result { - let session = Session::from_request_parts(req, state).await?; - - let mut guest_data: GuestData = session - .get(Self::GUEST_DATA_KEY) - .await - .unwrap() - .unwrap_or_default(); - - guest_data.last_seen = OffsetDateTime::now_utc(); - - Self::update_session(&session, &guest_data).await; - - Ok(Self { - session, - guest_data, - }) - } -} - -// This demonstrates a `Guest` extractor, but we could have any number of -// namespaced, strongly-typed "buckets" like `Guest` in the same session. -// -// Use cases could include buckets for site preferences, analytics, -// feature flags, etc. -async fn handler(mut guest: Guest) -> impl IntoResponse { - guest.mark_pageview().await; - format!("{}", guest) -} - -#[tokio::main] -async fn main() { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_secure(false); - - let app = Router::new().route("/", get(handler)).layer(session_layer); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app.into_make_service()) - .await - .unwrap(); -} diff --git a/memory-store/Cargo.toml b/memory-store/Cargo.toml index b32388f..608e60a 100644 --- a/memory-store/Cargo.toml +++ b/memory-store/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "tower-sessions-memory-store" -description = "Memory session store. Not for direct use; see the `tower-sessions` crate for details." +name = "tower-sesh-memory-store" +description = "Memory session store. Not for direct use; see the `tower-sesh` crate for details." documentation.workspace = true version.workspace = true license.workspace = true @@ -9,10 +9,11 @@ authors.workspace = true repository.workspace = true [dependencies] -tower-sessions-core = { workspace = true } -async-trait = { workspace = true } -time = { workspace = true } +tower-sesh-core = { workspace = true } tokio = { workspace = true } +time = { workspace = true } +rand = "0.8.5" [dev-dependencies] -tower-sessions = { workspace = true } +tower-sesh = { workspace = true } +tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/memory-store/src/lib.rs b/memory-store/src/lib.rs index c86e819..97d2933 100644 --- a/memory-store/src/lib.rs +++ b/memory-store/src/lib.rs @@ -1,134 +1,183 @@ -use std::{collections::HashMap, sync::Arc}; +use std::sync::Mutex; +use std::{collections::HashMap, convert::Infallible, sync::Arc}; -use async_trait::async_trait; +use std::fmt::Debug; use time::OffsetDateTime; -use tokio::sync::Mutex; -use tower_sessions_core::{ - session::{Id, Record}, - session_store, SessionStore, -}; +use tower_sesh_core::{expires::Expires, Expiry, Id, SessionStore}; /// A session store that lives only in memory. /// /// This is useful for testing but not recommended for real applications. /// +/// The store manages the expiry of the sessions with respect to UTC time. No cleanup is done for +/// the expired sessions untile the are loaded. +/// /// # Examples /// /// ```rust -/// use tower_sessions::MemoryStore; -/// MemoryStore::default(); +/// use tower_sesh_memory_store::MemoryStore; +/// +/// struct User { +/// name: String, +/// age: u8, +/// } +/// +/// let store: MemoryStore = MemoryStore::default(); /// ``` -#[derive(Clone, Debug, Default)] -pub struct MemoryStore(Arc>>); - -#[async_trait] -impl SessionStore for MemoryStore { - async fn create(&self, record: &mut Record) -> session_store::Result<()> { - let mut store_guard = self.0.lock().await; - while store_guard.contains_key(&record.id) { - // Session ID collision mitigation. - record.id = Id::default(); +#[derive(Debug)] +pub struct MemoryStore(Arc>>>); + +impl Default for MemoryStore { + fn default() -> Self { + MemoryStore(Default::default()) + } +} + +impl Clone for MemoryStore { + fn clone(&self) -> Self { + MemoryStore(self.0.clone()) + } +} + +#[derive(Debug, Clone)] +struct Value { + data: R, + // Needed because if the expiry date is set to `OnInactivity`, we need to know whether the + // session is active or not. + expiry_date: Option, +} + +impl Value { + /// Create a new `MemoryStore`. + pub fn new(data: R) -> Self { + let expiry_date = match data.expires() { + Expiry::OnSessionEnd => None, + Expiry::OnInactivity(duration) => Some(OffsetDateTime::now_utc() + duration), + Expiry::AtDateTime(offset_date_time) => Some(offset_date_time), + }; + + Value { data, expiry_date } + } +} + +impl SessionStore for MemoryStore +where + R: Expires + Send + Sync + Clone, +{ + type Error = Infallible; + + async fn create(&mut self, record: &R) -> Result { + let mut id = random_id(); + let mut store = self.0.lock().unwrap(); + while store.contains_key(&id) { + // If the ID already exists, generate a new one + id = random_id(); + } + + let value = Value::new(record.clone()); + + store.insert(id, value); + Ok(id) + } + + async fn save(&mut self, id: &Id, record: &R) -> Result { + let mut store = self.0.lock().unwrap(); + if store.contains_key(id) { + let value = Value::new(record.clone()); + store.insert(*id, value); + Ok(true) + } else { + Ok(false) } - store_guard.insert(record.id, record.clone()); - Ok(()) } - async fn save(&self, record: &Record) -> session_store::Result<()> { - self.0.lock().await.insert(record.id, record.clone()); + async fn save_or_create(&mut self, id: &Id, record: &R) -> Result<(), Self::Error> { + let mut store = self.0.lock().unwrap(); + let value = Value::new(record.clone()); + store.insert(*id, value); Ok(()) } - async fn load(&self, session_id: &Id) -> session_store::Result> { - Ok(self - .0 - .lock() - .await - .get(session_id) - .filter(|Record { expiry_date, .. }| is_active(*expiry_date)) - .cloned()) + async fn load(&mut self, id: &Id) -> Result, Self::Error> { + let mut store = self.0.lock().unwrap(); + + let Some(value) = store.get(id) else { + return Ok(None); + }; + Ok(match value.expiry_date { + Some(expiry_date) if expiry_date > OffsetDateTime::now_utc() => { + store.remove(id); + None + } + _ => Some(value.data.clone()), + }) + } + + async fn delete(&mut self, id: &Id) -> Result { + let mut store = self.0.lock().unwrap(); + Ok(store.remove(id).is_some()) } - async fn delete(&self, session_id: &Id) -> session_store::Result<()> { - self.0.lock().await.remove(session_id); - Ok(()) + async fn cycle_id(&mut self, old_id: &Id) -> Result, Self::Error> { + let mut store = self.0.lock().unwrap(); + if let Some(record) = store.remove(old_id) { + let mut new_id = random_id(); + while store.contains_key(&new_id) { + // If the ID already exists, generate a new one + new_id = random_id(); + } + store.insert(new_id, record); + Ok(Some(new_id)) + } else { + Ok(None) + } } } -fn is_active(expiry_date: OffsetDateTime) -> bool { - expiry_date > OffsetDateTime::now_utc() +fn random_id() -> Id { + use rand::prelude::*; + let id_val = rand::thread_rng().gen(); + Id(id_val) } #[cfg(test)] mod tests { - use time::Duration; - use super::*; + use tower_sesh_core::SessionStore; - #[tokio::test] - async fn test_create() { - let store = MemoryStore::default(); - let mut record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - assert!(store.create(&mut record).await.is_ok()); + + #[derive(Debug, Clone)] + struct SimpleUser { + age: u8, } - #[tokio::test] - async fn test_save() { - let store = MemoryStore::default(); - let record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - assert!(store.save(&record).await.is_ok()); - } + impl Expires for SimpleUser {} #[tokio::test] - async fn test_load() { - let store = MemoryStore::default(); - let mut record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - store.create(&mut record).await.unwrap(); - let loaded_record = store.load(&record.id).await.unwrap(); - assert_eq!(Some(record), loaded_record); - } + async fn round_trip() { + let mut store: MemoryStore = MemoryStore::default(); - #[tokio::test] - async fn test_delete() { - let store = MemoryStore::default(); - let mut record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - store.create(&mut record).await.unwrap(); - assert!(store.delete(&record.id).await.is_ok()); - assert_eq!(None, store.load(&record.id).await.unwrap()); - } + let id = store.create(&SimpleUser { + age: 20, + }).await.unwrap(); - #[tokio::test] - async fn test_create_id_collision() { - let store = MemoryStore::default(); - let expiry_date = OffsetDateTime::now_utc() + Duration::minutes(30); - let mut record1 = Record { - id: Default::default(), - data: Default::default(), - expiry_date, - }; - let mut record2 = Record { - id: Default::default(), - data: Default::default(), - expiry_date, - }; - store.create(&mut record1).await.unwrap(); - record2.id = record1.id; // Set the same ID for record2 - store.create(&mut record2).await.unwrap(); - assert_ne!(record1.id, record2.id); // IDs should be different + let mut user = store.load(&id).await.unwrap().unwrap(); + assert_eq!(20, user.age); + + user.age = 30; + assert!(store.save(&id, &user).await.unwrap()); + + let user = store.load(&id).await.unwrap().unwrap(); + assert_eq!(30, user.age); + + let new_id = store.cycle_id(&id).await.unwrap().unwrap(); + assert_ne!(id, new_id); + + assert!(store.load(&id).await.unwrap().is_none()); + let user = store.load(&new_id).await.unwrap().unwrap(); + assert_eq!(30, user.age); + + assert!(store.delete(&new_id).await.unwrap()); + assert!(store.load(&new_id).await.unwrap().is_none()); } } diff --git a/src/lib.rs b/src/lib.rs index faa7fc6..25b1c52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,432 +1,5 @@ -//! # Overview -//! -//! This crate provides sessions, key-value pairs associated with a site -//! visitor, as a [`tower`](https://docs.rs/tower/latest/tower/) middleware. -//! -//! It offers: -//! -//! - **Pluggable Storage Backends:** Bring your own backend simply by -//! implementing the [`SessionStore`] trait, fully decoupling sessions from -//! their storage. -//! - **Minimal Overhead**: Sessions are only loaded from their backing stores -//! when they're actually used and only in e.g. the handler they're used in. -//! That means this middleware can be installed at any point in your route -//! graph with minimal overhead. -//! - **An `axum` Extractor for [`Session`]:** Applications built with `axum` -//! can use `Session` as an extractor directly in their handlers. This makes -//! using sessions as easy as including `Session` in your handler. -//! - **Simple Key-Value Interface:** Sessions offer a key-value interface that -//! supports native Rust types. So long as these types are `Serialize` and can -//! be converted to JSON, it's straightforward to insert, get, and remove any -//! value. -//! - **Strongly-Typed Sessions:** Strong typing guarantees are easy to layer on -//! top of this foundational key-value interface. -//! -//! This crate's session implementation is inspired by the [Django sessions middleware](https://docs.djangoproject.com/en/4.2/topics/http/sessions) and it provides a transliteration of those semantics. -//! ### Session stores -//! -//! Session data persistence is managed by user-provided types that implement -//! [`SessionStore`]. What this means is that applications can and should -//! implement session stores to fit their specific needs. -//! -//! That said, a number of session store implmentations already exist and may be -//! useful starting points. -//! -//! | Crate | Persistent | Description | -//! | ---------------------------------------------------------------------------------------------------------------- | ---------- | ------------------------------------------ | -//! | [`tower-sessions-dynamodb-store`](https://github.com/necrobious/tower-sessions-dynamodb-store) | Yes | DynamoDB session store | -//! | [`tower-sessions-firestore-store`](https://github.com/AtTheTavern/tower-sessions-firestore-store) | Yes | Firestore session store | -//! | [`tower-sessions-libsql-store`](https://github.com/daybowbow-dev/tower-sessions-libsql-store) | Yes | libSQL session store | -//! | [`tower-sessions-mongodb-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/mongodb-store) | Yes | MongoDB session store | -//! | [`tower-sessions-moka-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/moka-store) | No | Moka session store | -//! | [`tower-sessions-redis-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/redis-store) | Yes | Redis via `fred` session store | -//! | [`tower-sessions-rusqlite-store`](https://github.com/patte/tower-sessions-rusqlite-store) | Yes | Rusqlite session store | -//! | [`tower-sessions-sled-store`](https://github.com/Zatzou/tower-sessions-sled-store) | Yes | Sled session store | -//! | [`tower-sessions-sqlx-store`](https://github.com/maxcountryman/tower-sessions-stores/tree/main/sqlx-store) | Yes | SQLite, Postgres, and MySQL session stores | -//! | [`tower-sessions-surrealdb-store`](https://github.com/rynoV/tower-sessions-surrealdb-store) | Yes | SurrealDB session store | -//! -//! Have a store to add? Please open a PR adding it. -//! -//! ### User session management -//! -//! To facilitate authentication and authorization, we've built [`axum-login`](https://github.com/maxcountryman/axum-login) on top of this crate. Please check it out if you're looking for a generalized auth solution. -//! -//! # Usage with an `axum` application -//! -//! A common use-case for sessions is when building HTTP servers. Using `axum`, -//! it's straightforward to leverage sessions. -//! -//! ```rust,no_run -//! use std::net::SocketAddr; -//! -//! use axum::{response::IntoResponse, routing::get, Router}; -//! use serde::{Deserialize, Serialize}; -//! use time::Duration; -//! use tower_sessions::{Expiry, MemoryStore, Session, SessionManagerLayer}; -//! -//! const COUNTER_KEY: &str = "counter"; -//! -//! #[derive(Default, Deserialize, Serialize)] -//! struct Counter(usize); -//! -//! async fn handler(session: Session) -> impl IntoResponse { -//! let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); -//! session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); -//! format!("Current count: {}", counter.0) -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! let session_store = MemoryStore::default(); -//! let session_layer = SessionManagerLayer::new(session_store) -//! .with_secure(false) -//! .with_expiry(Expiry::OnInactivity(Duration::seconds(10))); -//! -//! let app = Router::new().route("/", get(handler)).layer(session_layer); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); -//! axum::serve(listener, app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! ## Session expiry management -//! -//! In cases where you are utilizing stores that lack automatic session expiry -//! functionality, such as SQLx or MongoDB stores, it becomes essential to -//! periodically clean up stale sessions. For instance, both SQLx and MongoDB -//! stores offer -//! `continuously_delete_expired` -//! which is designed to be executed as a recurring task. This process ensures -//! the removal of expired sessions, maintaining your application's data -//! integrity and performance. -//! ```rust,no_run,ignore -//! # use tower_sessions::{session_store::ExpiredDeletion}; -//! # use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore}; -//! # tokio_test::block_on(async { -//! let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); -//! let session_store = SqliteStore::new(pool); -//! let deletion_task = tokio::task::spawn( -//! session_store -//! .clone() -//! .continuously_delete_expired(tokio::time::Duration::from_secs(60)), -//! ); -//! deletion_task.await.unwrap().unwrap(); -//! # }); -//! ``` -//! -//! Note that by default or when using browser session expiration, sessions are -//! considered expired after two weeks. -//! -//! # Extractor pattern -//! -//! When using `axum`, the [`Session`] will already function as an extractor. -//! It's possible to build further on this to create extractors of custom types. -//! ```rust,no_run -//! # use async_trait::async_trait; -//! # use axum::extract::FromRequestParts; -//! # use http::{request::Parts, StatusCode}; -//! # use serde::{Deserialize, Serialize}; -//! # use tower_sessions::{SessionStore, Session, MemoryStore}; -//! const COUNTER_KEY: &str = "counter"; -//! -//! #[derive(Default, Deserialize, Serialize)] -//! struct Counter(usize); -//! -//! #[async_trait] -//! impl FromRequestParts for Counter -//! where -//! S: Send + Sync, -//! { -//! type Rejection = (http::StatusCode, &'static str); -//! -//! async fn from_request_parts(req: &mut Parts, state: &S) -> Result { -//! let session = Session::from_request_parts(req, state).await?; -//! let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default(); -//! session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap(); -//! -//! Ok(counter) -//! } -//! } -//! ``` -//! -//! Now in our handler, we can use `Counter` directly to read its fields. -//! -//! A complete example can be found in [`examples/counter-extractor.rs`](https://github.com/maxcountryman/tower-sessions/blob/main/examples/counter-extractor.rs). -//! -//! # Strongly-typed sessions -//! -//! The extractor pattern can be extended further to provide strong typing -//! guarantees over the key-value substrate. Whereas our previous extractor -//! example was effectively read-only. This pattern enables mutability of the -//! underlying structure while also leveraging the full power of the type -//! system. -//! ```rust,no_run -//! # use async_trait::async_trait; -//! # use axum::extract::FromRequestParts; -//! # use http::{request::Parts, StatusCode}; -//! # use serde::{Deserialize, Serialize}; -//! # use time::OffsetDateTime; -//! # use tower_sessions::{SessionStore, Session}; -//! #[derive(Clone, Deserialize, Serialize)] -//! struct GuestData { -//! pageviews: usize, -//! first_seen: OffsetDateTime, -//! last_seen: OffsetDateTime, -//! } -//! -//! impl Default for GuestData { -//! fn default() -> Self { -//! Self { -//! pageviews: 0, -//! first_seen: OffsetDateTime::now_utc(), -//! last_seen: OffsetDateTime::now_utc(), -//! } -//! } -//! } -//! -//! struct Guest { -//! session: Session, -//! guest_data: GuestData, -//! } -//! -//! impl Guest { -//! const GUEST_DATA_KEY: &'static str = "guest_data"; -//! -//! fn first_seen(&self) -> OffsetDateTime { -//! self.guest_data.first_seen -//! } -//! -//! fn last_seen(&self) -> OffsetDateTime { -//! self.guest_data.last_seen -//! } -//! -//! fn pageviews(&self) -> usize { -//! self.guest_data.pageviews -//! } -//! -//! async fn mark_pageview(&mut self) { -//! self.guest_data.pageviews += 1; -//! Self::update_session(&self.session, &self.guest_data).await -//! } -//! -//! async fn update_session(session: &Session, guest_data: &GuestData) { -//! session -//! .insert(Self::GUEST_DATA_KEY, guest_data.clone()) -//! .await -//! .unwrap() -//! } -//! } -//! -//! #[async_trait] -//! impl FromRequestParts for Guest -//! where -//! S: Send + Sync, -//! { -//! type Rejection = (StatusCode, &'static str); -//! -//! async fn from_request_parts(req: &mut Parts, state: &S) -> Result { -//! let session = Session::from_request_parts(req, state).await?; -//! -//! let mut guest_data: GuestData = session -//! .get(Self::GUEST_DATA_KEY) -//! .await -//! .unwrap() -//! .unwrap_or_default(); -//! -//! guest_data.last_seen = OffsetDateTime::now_utc(); -//! -//! Self::update_session(&session, &guest_data).await; -//! -//! Ok(Self { -//! session, -//! guest_data, -//! }) -//! } -//! } -//! ``` -//! -//! Here we can use `Guest` as an extractor in our handler. We'll be able to -//! read values, like the ID as well as update the pageview count with our -//! `mark_pageview` method. -//! -//! A complete example can be found in [`examples/strongly-typed.rs`](https://github.com/maxcountryman/tower-sessions/blob/main/examples/strongly-typed.rs) -//! -//! ## Name-spaced and strongly-typed buckets -//! -//! Our example demonstrates a single extractor, but in a real application we -//! might imagine a set of common extractors, all living in the same session. -//! Each extractor forms a kind of bucketed name-space with a typed structure. -//! Importantly, each is self-contained by its own name-space. -//! -//! For instance, we might also have a site preferences bucket, an analytics -//! bucket, a feature flag bucket and so on. All these together would live in -//! the same session, but would be segmented by their own name-space, avoiding -//! the mixing of domains unnecessarily.[^data-domains] -//! -//! # Layered caching -//! -//! In some cases, the canonical store for a session may benefit from a cache. -//! For example, rather than loading a session from a store on every request, -//! this roundtrip can be mitigated by placing a cache in front of the storage -//! backend. A specialized session store, [`CachingSessionStore`], is provided -//! for exactly this purpose. -//! -//! This store manages a cache and a store. Where the cache acts as a frontend -//! and the store a backend. When a session is loaded, the store first attempts -//! to load the session from the cache, if that fails only then does it try to -//! load from the store. By doing so, read-heavy workloads will incur far fewer -//! roundtrips to the store itself. -//! -//! To illustrate, this is how we might use the -//! `MokaStore` as a frontend cache to a -//! `PostgresStore` backend. -//! ```rust,no_run,ignore -//! # use tower::ServiceBuilder; -//! # use tower_sessions::{CachingSessionStore, SessionManagerLayer}; -//! # use tower_sessions_sqlx_store::{sqlx::PgPool, PostgresStore}; -//! # use tower_sessions_moka_store::MokaStore; -//! # use time::Duration; -//! # tokio_test::block_on(async { -//! let database_url = std::option_env!("DATABASE_URL").unwrap(); -//! let pool = PgPool::connect(database_url).await.unwrap(); -//! -//! let postgres_store = PostgresStore::new(pool); -//! postgres_store.migrate().await.unwrap(); -//! -//! let moka_store = MokaStore::new(Some(10_000)); -//! let caching_store = CachingSessionStore::new(moka_store, postgres_store); -//! -//! let session_service = ServiceBuilder::new() -//! .layer(SessionManagerLayer::new(caching_store).with_max_age(Duration::days(1))); -//! # }) -//! ``` -//! -//! While this example uses Moka, any implementor of [`SessionStore`] may be -//! used. For instance, we could use the `RedisStore` instead of Moka. -//! -//! A cache is most helpful with read-heavy workloads, where the cache hit rate -//! will be high. This is because write-heavy workloads will require a roundtrip -//! to the store and therefore benefit less from caching. -//! -//! ## Data races under concurrent conditions -//! -//! Please note that it is **not safe** to access and mutate session state -//! concurrently: this will result in data loss if your mutations are dependent -//! on the state of the session. -//! -//! This is because a session is loaded first from its backing store. Once -//! loaded it's possible for a second request to load the same session, but -//! without the inflight changes the first request may have made. -//! -//! # Implementation -//! -//! Sessions are composed of three pieces: -//! -//! 1. A cookie that holds the session ID as its value, -//! 2. An in-memory hash-map, which underpins the key-value API, -//! 3. A pluggable persistence layer, the session store, where session data is -//! housed. -//! -//! Together, these pieces form the basis of this crate and allow `tower` and -//! `axum` applications to use a familiar session interface. -//! -//! ## Cookie -//! -//! Sessions manifest to clients as cookies. These cookies have a configurable -//! name and a value that is the session ID. In other words, cookies hold a -//! pointer to the session in the form of an ID. This ID is an i128 generated by -//! the [`rand`](https://docs.rs/rand/latest/rand) crate. -//! -//! ### Secure nature of cookies -//! -//! Session IDs are considered secure if sent over encrypted channels. Note that -//! this assumption is predicated on the secure nature of the [`rand`](https://docs.rs/rand/latest/rand) crate -//! and its ability to generate securely-random values using the ChaCha block -//! cipher with 12 rounds. It's also important to note that session cookies -//! **must never** be sent over a public, insecure channel. Doing so is **not** -//! secure and will lead to compromised sessions! -//! -//! Additionally, sessions may be optionally signed or encrypted by enabling the -//! `signed` and `private` feature flags, respectively. When enabled, the -//! [`with_signed`](SessionManagerLayer::with_signed) and -//! [`with_private`](SessionManagerLayer::with_private) methods become -//! available. These methods take a cryptographic key which allows the session -//! manager to leverage ciphertext as opposed to the default of plaintext. Note -//! that no data is stored in the session ID beyond the session identifier -//! itself and so this measure should be considered primarily effective as a -//! defense in depth tactic. -//! -//! ## Key-value API -//! -//! Sessions manage a `HashMap` but importantly are -//! transparently persisted to an arbitrary storage backend. Effectively, -//! `HashMap` is an intermediary, in-memory representation. By using a map-like -//! structure, we're able to present a familiar key-value interface for managing -//! sessions. This allows us to store and retrieve native Rust types, so long as -//! our type is `impl Serialize` and can be represented as JSON.[^json] -//! -//! Internally, this hash map state is protected by a lock in the form of -//! `Mutex`. This allows us to safely share mutable state across thread -//! boundaries. Note that this lock is only acquired when we read from or write -//! to this inner session state and not used when the session is provided to the -//! request. This means that lock contention is minimized for most use -//! cases.[^lock-contention] -//! -//! ## Session store -//! -//! Sessions are serialized to arbitrary storage backends via a session record -//! intermediary. Implementations of `SessionStore` take a record and persist -//! it such that it can later be loaded via the session ID. -//! -//! Three components are needed for storing a session: -//! -//! 1. The session ID. -//! 2. The session expiry. -//! 3. The session data itself. -//! -//! Together, these compose the session record and are enough to both encode and -//! decode a session from any backend. -//! -//! ## Session life cycle -//! -//! Cookies hold a pointer to the session, rather than the session's data, and -//! because of this, the `tower` middleware is focused on managing the process -//! of initializing a session which can later be used in code to transparently -//! interact with the store. -//! -//! A session is initialized by looking for a cookie that matches the configured -//! session cookie name. If no such cookie is found or a cookie is found but is -//! malformed, an empty session is initialized. -//! -//! Modified sessions will invoke the session's [`save`](Session::save) method -//! as well as append to the `Set-Cookie` header of the response. -//! -//! Empty sessions are considered deleted and will set a removal cookie -//! on the response but are not removed from the store directly. -//! -//! Sessions also carry with them a configurable expiry and will be removed in -//! accordance with this. -//! -//! Notably, the session life cycle minimizes overhead with the store. All -//! session store methods are deferred until the point [`Session`] is used in -//! code and more specifically one of its methods requiring the store is called. -//! -//! [^json]: Using JSON allows us to translate arbitrary types to virtually -//! any backend and gives us a nice interface with which to interact with the -//! session. -//! -//! [^lock-contention]: We might consider replacing `Mutex` with `RwLock` if -//! this proves to be a better fit in practice. Another alternative might be -//! `dashmap` or a different approach entirely. Future iterations should be -//! based on real-world use cases. -//! -//! [^data-domains]: This is particularly useful when we may have data -//! domains that only belong with ! users in certain states: we can pull these -//! into our handlers where we need a particular domain. In this way, we -//! minimize data pollution via self-contained domains in the form of buckets. +#![doc = include_str!("../README.md")] + #![warn( clippy::all, nonstandard_style, @@ -434,21 +7,22 @@ missing_debug_implementations )] #![deny(missing_docs)] -#![forbid(unsafe_code)] #![cfg_attr(docsrs, feature(doc_cfg))] -pub use tower_cookies::cookie; -pub use tower_sessions_core::{session, session_store}; +pub use tower_sesh_core::session_store; #[doc(inline)] -pub use tower_sessions_core::{ - session::{Expiry, Session}, - session_store::{CachingSessionStore, ExpiredDeletion, SessionStore}, +pub use tower_sesh_core::{ + id::Id, + expires::{Expires, Expiry}, + session_store::{CachingSessionStore, SessionStore}, }; #[cfg(feature = "memory-store")] #[cfg_attr(docsrs, doc(cfg(feature = "memory-store")))] #[doc(inline)] -pub use tower_sessions_memory_store::MemoryStore; +pub use tower_sesh_memory_store::MemoryStore; -pub use crate::service::{SessionManager, SessionManagerLayer}; +pub use crate::middleware::{SessionManager, SessionManagerLayer}; +pub use crate::session::{Session, SessionState}; -pub mod service; +pub mod middleware; +pub mod session; diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 0000000..1ffa57f --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,504 @@ +//! A middleware that provides [`Session`] as a request extension. +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use cookie::{Cookie, SameSite}; +use http::{header::COOKIE, Request, Response}; +use pin_project_lite::pin_project; +use time::OffsetDateTime; +use tower_layer::Layer; +use tower_service::Service; +use tower_sesh_core::{expires::Expiry, id::Id}; +use tracing::{instrument::Instrumented, Instrument}; + +use crate::{ + session::{SessionUpdate, Updater}, + Session, +}; + +/// the configuration options for the [`SessionManagerLayer`]. +/// +/// ## Default +/// ``` +/// # use tower_sesh::middleware::Config; +/// # use tower_sesh::Expiry; +/// # use cookie::SameSite; +/// let default = Config { +/// name: "id", +/// http_only: true, +/// same_site: SameSite::Strict, +/// secure: true, +/// path: "/", +/// domain: None, +/// always_set_expiry: None, +/// }; +/// +/// assert_eq!(default, Config::default()); +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Config<'a> { + /// The name of the cookie. + pub name: &'a str, + /// Whether the cookie is [HTTP only](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#httponly). + pub http_only: bool, + /// The + /// [SameSite](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#samesitesamesite-value) + /// policy. + pub same_site: SameSite, + /// Whether the cookie should be + /// [secure](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#secure). + pub secure: bool, + /// The [path](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#pathpath-value) + /// attribute of the cookie. + pub path: &'a str, + /// The + /// [domain](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#domaindomain-value) + /// attribute of the cookie. + pub domain: Option<&'a str>, + /// If this is set to `None`, the session will only be saved if it is modified. If it is set to + /// `Some(expiry)`, the session will be saved as usual if it is modified, but it will also be + /// saved with the provided `expiry` when it is not modified. + /// + /// This manages the + /// [`Max-Age`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#max-agenumber) + /// and the + /// [`Expires`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#expiresdate) + /// attributes. + pub always_set_expiry: Option, +} + +impl<'a> Config<'a> { + fn build_cookie(self, session_id: Option, expiry: Expiry) -> Cookie<'a> { + let mut cookie_builder = Cookie::build(( + self.name, + session_id + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(), + )) + .http_only(self.http_only) + .same_site(self.same_site) + .secure(self.secure) + .path(self.path); + + cookie_builder = match expiry { + Expiry::OnInactivity(duration) => cookie_builder.max_age(duration), + Expiry::AtDateTime(datetime) => { + cookie_builder.max_age(datetime - OffsetDateTime::now_utc()) + } + Expiry::OnSessionEnd => cookie_builder, + }; + + if let Some(domain) = self.domain { + cookie_builder = cookie_builder.domain(domain); + } + + cookie_builder.build() + } +} + +impl Default for Config<'static> { + fn default() -> Self { + Self { + name: "id", /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */ + http_only: true, + same_site: SameSite::Strict, + secure: true, + path: "/", + domain: None, + always_set_expiry: None, + } + } +} + +/// A middleware that provides [`Session`] as a request extension. +#[derive(Debug, Clone)] +pub struct SessionManager { + inner: S, + store: Store, + config: Config<'static>, +} + +impl SessionManager { + /// Create a new [`SessionManager`]. + /// + /// # Examples + /// ``` + /// use tower_sesh::{MemoryStore, SessionManager}; + /// + /// struct MyService; + /// + /// let _ = SessionManager::new(MyService, MemoryStore::<()>::default(), Default::default()); + /// ``` + pub fn new(inner: S, store: Store, config: Config<'static>) -> Self { + Self { + inner, + store, + config, + } + } +} + +impl Service> for SessionManager +where + S: Service, Response = Response>, + Store: Clone + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = Instrumented>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let span = tracing::debug_span!("session_manager"); + let _enter = span.enter(); + + let session_cookie = req + .headers() + .get_all(COOKIE) + .into_iter() + .filter_map(|value| value.to_str().ok()) + .flat_map(|value| value.split(';')) + .filter_map(|cookie| Cookie::parse(cookie).ok()) + .find(|cookie| cookie.name() == self.config.name); + + let id = session_cookie.and_then(|cookie| { + cookie + .value() + .parse::() + .map_err(|err| { + tracing::warn!( + err = %err, + "possibly suspicious activity: malformed session id" + ) + }) + .ok() + }); + + let updater = Arc::new(Mutex::new(None)); + let session = Session { + id, + store: self.store.clone(), + updater: Arc::clone(&updater), + }; + tracing::debug!("adding session to request extensions"); + req.extensions_mut().insert(session); + + drop(_enter); + ResponseFuture { + inner: self.inner.call(req), + updater, + config: self.config, + old_id: id, + } + .instrument(span) + } +} + +pin_project! { + #[derive(Debug, Clone)] + /// The future returned by [`SessionManager`]. + pub struct ResponseFuture { + #[pin] + inner: F, + updater: Updater, + config: Config<'static>, + old_id: Option, + } +} + +impl Future for ResponseFuture +where + F: Future, Error>>, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let self_ = self.project(); + let mut resp = match self_.inner.poll(cx) { + Poll::Ready(r) => r, + Poll::Pending => return Poll::Pending, + }?; + + let update = self_ + .updater + .lock() + .expect("updater should not be poisoned") + .or_else(|| { + self_ + .config + .always_set_expiry + .and_then(|expiry| self_.old_id.map(|id| SessionUpdate::Set(id, expiry))) + }); + match update { + Some(SessionUpdate::Delete) => { + tracing::debug!("deleting session"); + let cookie = self_.config.build_cookie( + *self_.old_id, + Expiry::AtDateTime( + // The Year 2000 in UNIX time. + time::OffsetDateTime::from_unix_timestamp(946684800) + .expect("year 2000 should be in range"), + ), + ); + resp.headers_mut().insert( + http::header::SET_COOKIE, + cookie + .to_string() + .try_into() + .expect("cookie should be valid"), + ); + } + Some(SessionUpdate::Set(id, expiry)) => { + tracing::debug!("setting session {id}, expiring: {:?}", expiry); + let cookie = self_.config.build_cookie(Some(id), expiry); + resp.headers_mut().insert( + http::header::SET_COOKIE, + cookie + .to_string() + .try_into() + .expect("cookie should be valid"), + ); + } + None => {} + }; + + Poll::Ready(Ok(resp)) + } +} + +/// A layer for providing [`Session`] as a request extension. +/// +/// # Examples +/// +/// ```rust +/// use tower_sesh::{MemoryStore, SessionManagerLayer}; +/// +/// let session_store: MemoryStore<()> = MemoryStore::default(); +/// let session_service = SessionManagerLayer { +/// store: session_store, +/// config: Default::default() +/// }; +/// ``` +#[derive(Debug, Clone)] +pub struct SessionManagerLayer { + /// The store to use for session data. + /// + /// This should implement [`tower_sesh_core::SessionStore`], and be cloneable. + pub store: Store, + /// The configuration options for the session cookie. + pub config: Config<'static>, +} + +impl Layer for SessionManagerLayer +where + Store: Clone, +{ + type Service = SessionManager; + + fn layer(&self, inner: S) -> Self::Service { + SessionManager { + inner, + store: self.store.clone(), + config: self.config, + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use axum::body::Body; + use tower::{ServiceBuilder, ServiceExt}; + use tower_sesh_core::Expires; + use tower_sesh_memory_store::MemoryStore; + + use super::*; + + #[derive(Debug, Clone)] + struct Record { + foo: i32, + } + impl Expires for Record {} + + async fn handler(mut req: Request) -> anyhow::Result> { + let session = req + .extensions_mut() + .remove::>>() + .ok_or(anyhow!("Missing session"))?; + + let session_state = session.clone().load().await?; + if let Some(session_state) = session_state { + session_state + .update(|data| { + data.foo += 1; + }) + .await?; + } else { + session.create(Record { foo: 42 }).await?; + } + + Ok(Response::new(Body::empty())) + } + + async fn noop_handler(_: Request) -> anyhow::Result> { + Ok(Response::new(Body::empty())) + } + + #[tokio::test] + async fn basic_service_test() -> anyhow::Result<()> { + let session_store: MemoryStore = MemoryStore::default(); + let session_layer = SessionManagerLayer { + store: session_store, + config: Default::default(), + }; + let svc = ServiceBuilder::new() + .layer(session_layer.clone()) + .service_fn(handler); + + let noop_svc = ServiceBuilder::new() + .layer(session_layer) + .service_fn(noop_handler); + + let req = Request::builder().body(Body::empty())?; + let res = svc.clone().oneshot(req).await?; + + let session = res.headers().get(http::header::SET_COOKIE); + assert!(session.is_some()); + + let req = Request::builder() + .header(http::header::COOKIE, session.unwrap()) + .body(Body::empty())?; + let res = noop_svc.oneshot(req).await?; + + assert!(res.headers().get(http::header::SET_COOKIE).is_none()); + + Ok(()) + } + + #[tokio::test] + async fn bogus_cookie_test() -> anyhow::Result<()> { + let session_store: MemoryStore = MemoryStore::default(); + let session_layer = SessionManagerLayer { + store: session_store, + config: Default::default(), + }; + let svc = ServiceBuilder::new() + .layer(session_layer) + .service_fn(handler); + + let req = Request::builder().body(Body::empty())?; + let res = svc.clone().oneshot(req).await?; + + assert!(res.headers().get(http::header::SET_COOKIE).is_some()); + + let req = Request::builder() + .header(http::header::COOKIE, "id=bogus") + .body(Body::empty())?; + let res = svc.oneshot(req).await?; + + assert!(res.headers().get(http::header::SET_COOKIE).is_some()); + + Ok(()) + } + + #[tokio::test] + async fn no_set_cookie_test() -> anyhow::Result<()> { + let session_store: MemoryStore = MemoryStore::default(); + let session_layer = SessionManagerLayer { + store: session_store, + config: Default::default(), + }; + let svc = ServiceBuilder::new() + .layer(session_layer) + .service_fn(noop_handler); + + let req = Request::builder().body(Body::empty())?; + let res = svc.oneshot(req).await?; + + assert!(res.headers().get(http::header::SET_COOKIE).is_none()); + + Ok(()) + } + + #[tokio::test] + async fn custom_config() -> anyhow::Result<()> { + let session_store: MemoryStore = MemoryStore::default(); + + let session_config = Config { + name: "my.sid", + http_only: false, + same_site: SameSite::Lax, + secure: false, + path: "/foo/bar", + domain: Some("example.com"), + always_set_expiry: Some(Expiry::OnInactivity(time::Duration::hours(2))), + }; + let session_layer = SessionManagerLayer { + store: session_store, + config: session_config, + }; + let svc = ServiceBuilder::new() + .layer(session_layer.clone()) + .service_fn(handler); + let noop_svc = ServiceBuilder::new() + .layer(session_layer) + .service_fn(noop_handler); + + let req = Request::builder().body(Body::empty())?; + let res = svc.oneshot(req).await?; + + assert!(cookie_value_matches(&res, |s| s.contains("my.sid="))); + assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax"))); + assert!(cookie_value_matches(&res, |s| !s.contains("Secure"))); + assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar"))); + assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com"))); + + let req = Request::builder() + .header( + http::header::COOKIE, + res.headers().get(http::header::SET_COOKIE).unwrap(), + ) + .body(Body::empty())?; + let res = noop_svc.oneshot(req).await?; + assert!(cookie_has_expected_max_age(&res, 7200)); + + Ok(()) + } + + fn cookie_value_matches(res: &Response, matcher: F) -> bool + where + F: FnOnce(&str) -> bool, + { + res.headers() + .get(http::header::SET_COOKIE) + .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher)) + } + + fn cookie_has_expected_max_age(res: &Response, expected_value: i64) -> bool { + res.headers() + .get(http::header::SET_COOKIE) + .is_some_and(|set_cookie| { + set_cookie.to_str().is_ok_and(|s| { + let max_age_value = s + .split("Max-Age=") + .nth(1) + .unwrap_or_default() + .split(';') + .next() + .unwrap_or_default() + .parse::() + .unwrap_or_default(); + (max_age_value - expected_value).abs() <= 1 + }) + }) + } +} diff --git a/src/service.rs b/src/service.rs deleted file mode 100644 index 2c573f8..0000000 --- a/src/service.rs +++ /dev/null @@ -1,1012 +0,0 @@ -//! A middleware that provides [`Session`] as a request extension. -use std::{ - borrow::Cow, - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use http::{Request, Response}; -use time::OffsetDateTime; -#[cfg(any(feature = "signed", feature = "private"))] -use tower_cookies::Key; -use tower_cookies::{cookie::SameSite, Cookie, CookieManager, Cookies}; -use tower_layer::Layer; -use tower_service::Service; -use tracing::Instrument; - -use crate::{ - session::{self, Expiry}, - Session, SessionStore, -}; - -#[doc(hidden)] -pub trait CookieController: Clone + Send + 'static { - fn get(&self, cookies: &Cookies, name: &str) -> Option>; - fn add(&self, cookies: &Cookies, cookie: Cookie<'static>); - fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>); -} - -#[doc(hidden)] -#[derive(Debug, Clone)] -pub struct PlaintextCookie; - -impl CookieController for PlaintextCookie { - fn get(&self, cookies: &Cookies, name: &str) -> Option> { - cookies.get(name).map(Cookie::into_owned) - } - - fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.add(cookie) - } - - fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.remove(cookie) - } -} - -#[doc(hidden)] -#[cfg(feature = "signed")] -#[derive(Debug, Clone)] -pub struct SignedCookie { - key: Key, -} - -#[cfg(feature = "signed")] -impl CookieController for SignedCookie { - fn get(&self, cookies: &Cookies, name: &str) -> Option> { - cookies.signed(&self.key).get(name).map(Cookie::into_owned) - } - - fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.signed(&self.key).add(cookie) - } - - fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.signed(&self.key).remove(cookie) - } -} - -#[doc(hidden)] -#[cfg(feature = "private")] -#[derive(Debug, Clone)] -pub struct PrivateCookie { - key: Key, -} - -#[cfg(feature = "private")] -impl CookieController for PrivateCookie { - fn get(&self, cookies: &Cookies, name: &str) -> Option> { - cookies.private(&self.key).get(name).map(Cookie::into_owned) - } - - fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.private(&self.key).add(cookie) - } - - fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) { - cookies.private(&self.key).remove(cookie) - } -} - -#[derive(Debug, Clone)] -struct SessionConfig<'a> { - name: Cow<'a, str>, - http_only: bool, - same_site: SameSite, - expiry: Option, - secure: bool, - path: Cow<'a, str>, - domain: Option>, - always_save: bool, -} - -impl<'a> SessionConfig<'a> { - fn build_cookie(self, session_id: session::Id, expiry: Option) -> Cookie<'a> { - let mut cookie_builder = Cookie::build((self.name, session_id.to_string())) - .http_only(self.http_only) - .same_site(self.same_site) - .secure(self.secure) - .path(self.path); - - cookie_builder = match expiry { - Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration), - Some(Expiry::AtDateTime(datetime)) => { - cookie_builder.max_age(datetime - OffsetDateTime::now_utc()) - } - Some(Expiry::OnSessionEnd) | None => cookie_builder, - }; - - if let Some(domain) = self.domain { - cookie_builder = cookie_builder.domain(domain); - } - - cookie_builder.build() - } -} - -impl<'a> Default for SessionConfig<'a> { - fn default() -> Self { - Self { - name: "id".into(), /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */ - http_only: true, - same_site: SameSite::Strict, - expiry: None, // TODO: Is `Max-Age: "Session"` the right default? - secure: true, - path: "/".into(), - domain: None, - always_save: false, - } - } -} - -/// A middleware that provides [`Session`] as a request extension. -#[derive(Debug, Clone)] -pub struct SessionManager { - inner: S, - session_store: Arc, - session_config: SessionConfig<'static>, - cookie_controller: C, -} - -impl SessionManager { - /// Create a new [`SessionManager`]. - pub fn new(inner: S, session_store: Store) -> Self { - Self { - inner, - session_store: Arc::new(session_store), - session_config: Default::default(), - cookie_controller: PlaintextCookie, - } - } -} - -impl Service> - for SessionManager -where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send, - ReqBody: Send + 'static, - ResBody: Default + Send, -{ - type Response = S::Response; - type Error = S::Error; - type Future = Pin> + Send>>; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - let span = tracing::info_span!("call"); - - let session_store = self.session_store.clone(); - let session_config = self.session_config.clone(); - let cookie_controller = self.cookie_controller.clone(); - - // Because the inner service can panic until ready, we need to ensure we only - // use the ready service. - // - // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services - let clone = self.inner.clone(); - let mut inner = std::mem::replace(&mut self.inner, clone); - - Box::pin( - async move { - let Some(cookies) = req.extensions().get::<_>().cloned() else { - // In practice this should never happen because we wrap `CookieManager` - // directly. - tracing::error!("missing cookies request extension"); - return Ok(Response::default()); - }; - - let session_cookie = cookie_controller.get(&cookies, &session_config.name); - let session_id = session_cookie.as_ref().and_then(|cookie| { - cookie - .value() - .parse::() - .map_err(|err| { - tracing::warn!( - err = %err, - "possibly suspicious activity: malformed session id" - ) - }) - .ok() - }); - - let session = Session::new(session_id, session_store, session_config.expiry); - - req.extensions_mut().insert(session.clone()); - - let res = inner.call(req).await?; - - let modified = session.is_modified(); - let empty = session.is_empty().await; - - tracing::trace!( - modified = modified, - empty = empty, - always_save = session_config.always_save, - "session response state", - ); - - match session_cookie { - Some(mut cookie) if empty => { - tracing::debug!("removing session cookie"); - - // Path and domain must be manually set to ensure a proper removal cookie is - // constructed. - // - // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove - cookie.set_path(session_config.path); - if let Some(domain) = session_config.domain { - cookie.set_domain(domain); - } - - cookie_controller.remove(&cookies, cookie); - } - - _ if (modified || session_config.always_save) - && !empty - && !res.status().is_server_error() => - { - tracing::debug!("saving session"); - if let Err(err) = session.save().await { - tracing::error!(err = %err, "failed to save session"); - - let mut res = Response::default(); - *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; - return Ok(res); - } - - let Some(session_id) = session.id() else { - tracing::error!("missing session id"); - - let mut res = Response::default(); - *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; - return Ok(res); - }; - - let expiry = session.expiry(); - let session_cookie = session_config.build_cookie(session_id, expiry); - - tracing::debug!("adding session cookie"); - cookie_controller.add(&cookies, session_cookie); - } - - _ => (), - }; - - Ok(res) - } - .instrument(span), - ) - } -} - -/// A layer for providing [`Session`] as a request extension. -#[derive(Debug, Clone)] -pub struct SessionManagerLayer { - session_store: Arc, - session_config: SessionConfig<'static>, - cookie_controller: C, -} - -impl SessionManagerLayer { - /// Configures the name of the cookie used for the session. - /// The default value is `"id"`. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_name("my.sid"); - /// ``` - pub fn with_name>>(mut self, name: N) -> Self { - self.session_config.name = name.into(); - self - } - - /// Configures the `"HttpOnly"` attribute of the cookie used for the - /// session. - /// - /// # ⚠️ **Warning: Cross-site scripting risk** - /// - /// Applications should generally **not** override the default value of - /// `true`. If you do, you are exposing your application to increased risk - /// of cookie theft via techniques like cross-site scripting. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_http_only(true); - /// ``` - pub fn with_http_only(mut self, http_only: bool) -> Self { - self.session_config.http_only = http_only; - self - } - - /// Configures the `"SameSite"` attribute of the cookie used for the - /// session. - /// The default value is [`SameSite::Strict`]. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{cookie::SameSite, MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax); - /// ``` - pub fn with_same_site(mut self, same_site: SameSite) -> Self { - self.session_config.same_site = same_site; - self - } - - /// Configures the `"Max-Age"` attribute of the cookie used for the session. - /// The default value is `None`. - /// - /// # Examples - /// - /// ```rust - /// use time::Duration; - /// use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_expiry = Expiry::OnInactivity(Duration::hours(1)); - /// let session_service = SessionManagerLayer::new(session_store).with_expiry(session_expiry); - /// ``` - pub fn with_expiry(mut self, expiry: Expiry) -> Self { - self.session_config.expiry = Some(expiry); - self - } - - /// Configures the `"Secure"` attribute of the cookie used for the session. - /// The default value is `true`. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_secure(true); - /// ``` - pub fn with_secure(mut self, secure: bool) -> Self { - self.session_config.secure = secure; - self - } - - /// Configures the `"Path"` attribute of the cookie used for the session. - /// The default value is `"/"`. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_path("/some/path"); - /// ``` - pub fn with_path>>(mut self, path: P) -> Self { - self.session_config.path = path.into(); - self - } - - /// Configures the `"Domain"` attribute of the cookie used for the session. - /// The default value is `None`. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_domain("localhost"); - /// ``` - pub fn with_domain>>(mut self, domain: D) -> Self { - self.session_config.domain = Some(domain.into()); - self - } - - /// Configures whether unmodified session should be saved on read or not. - /// When the value is `true`, the session will be saved even if it was not - /// changed. - /// - /// This is useful when you want to reset [`Session`] expiration time - /// on any valid request at the cost of higher [`SessionStore`] write - /// activity and transmitting `set-cookie` header with each response. - /// - /// It makes sense to use this setting with relative session expiration - /// values, such as `Expiry::OnInactivity(Duration)`. This setting will - /// _not_ cause session id to be cycled on save. - /// - /// The default value is `false`. - /// - /// # Examples - /// - /// ```rust - /// use time::Duration; - /// use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_expiry = Expiry::OnInactivity(Duration::hours(1)); - /// let session_service = SessionManagerLayer::new(session_store) - /// .with_expiry(session_expiry) - /// .with_always_save(true); - /// ``` - pub fn with_always_save(mut self, always_save: bool) -> Self { - self.session_config.always_save = always_save; - self - } - - /// Manages the session cookie via a signed interface. - /// - /// See [`SignedCookies`](tower_cookies::SignedCookies). - /// - /// ```rust - /// use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer}; - /// - /// # /* - /// let key = { /* a cryptographically random key >= 64 bytes */ }; - /// # */ - /// # let key: &Vec = &(0..64).collect(); - /// # let key: &[u8] = &key[..]; - /// # let key = Key::try_from(key).unwrap(); - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_signed(key); - /// ``` - #[cfg(feature = "signed")] - pub fn with_signed(self, key: Key) -> SessionManagerLayer { - SessionManagerLayer:: { - session_store: self.session_store, - session_config: self.session_config, - cookie_controller: SignedCookie { key }, - } - } - - /// Manages the session cookie via an encrypted interface. - /// - /// See [`PrivateCookies`](tower_cookies::PrivateCookies). - /// - /// ```rust - /// use tower_sessions::{cookie::Key, MemoryStore, SessionManagerLayer}; - /// - /// # /* - /// let key = { /* a cryptographically random key >= 64 bytes */ }; - /// # */ - /// # let key: &Vec = &(0..64).collect(); - /// # let key: &[u8] = &key[..]; - /// # let key = Key::try_from(key).unwrap(); - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store).with_private(key); - /// ``` - #[cfg(feature = "private")] - pub fn with_private(self, key: Key) -> SessionManagerLayer { - SessionManagerLayer:: { - session_store: self.session_store, - session_config: self.session_config, - cookie_controller: PrivateCookie { key }, - } - } -} - -impl SessionManagerLayer { - /// Create a new [`SessionManagerLayer`] with the provided session store - /// and default cookie configuration. - /// - /// # Examples - /// - /// ```rust - /// use tower_sessions::{MemoryStore, SessionManagerLayer}; - /// - /// let session_store = MemoryStore::default(); - /// let session_service = SessionManagerLayer::new(session_store); - /// ``` - pub fn new(session_store: Store) -> Self { - let session_config = SessionConfig::default(); - - Self { - session_store: Arc::new(session_store), - session_config, - cookie_controller: PlaintextCookie, - } - } -} - -impl Layer for SessionManagerLayer { - type Service = CookieManager>; - - fn layer(&self, inner: S) -> Self::Service { - let session_manager = SessionManager { - inner, - session_store: self.session_store.clone(), - session_config: self.session_config.clone(), - cookie_controller: self.cookie_controller.clone(), - }; - - CookieManager::new(session_manager) - } -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use anyhow::anyhow; - use axum::body::Body; - use tower::{ServiceBuilder, ServiceExt}; - use tower_sessions_memory_store::MemoryStore; - - use crate::session::{Id, Record}; - - use super::*; - - async fn handler(req: Request) -> anyhow::Result> { - let session = req - .extensions() - .get::() - .ok_or(anyhow!("Missing session"))?; - - session.insert("foo", 42).await?; - - Ok(Response::new(Body::empty())) - } - - async fn noop_handler(_: Request) -> anyhow::Result> { - Ok(Response::new(Body::empty())) - } - - #[tokio::test] - async fn basic_service_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.clone().oneshot(req).await?; - - let session = res.headers().get(http::header::SET_COOKIE); - assert!(session.is_some()); - - let req = Request::builder() - .header(http::header::COOKIE, session.unwrap()) - .body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_none()); - - Ok(()) - } - - #[tokio::test] - async fn bogus_cookie_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.clone().oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_some()); - - let req = Request::builder() - .header(http::header::COOKIE, "id=bogus") - .body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_some()); - - Ok(()) - } - - #[tokio::test] - async fn no_set_cookie_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(noop_handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_none()); - - Ok(()) - } - - #[tokio::test] - async fn name_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid"); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid="))); - - Ok(()) - } - - #[tokio::test] - async fn http_only_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly"))); - - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_http_only(false); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly"))); - - Ok(()) - } - - #[tokio::test] - async fn same_site_strict_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = - SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict"))); - - Ok(()) - } - - #[tokio::test] - async fn same_site_lax_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax"))); - - Ok(()) - } - - #[tokio::test] - async fn same_site_none_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None"))); - - Ok(()) - } - - #[tokio::test] - async fn expiry_on_session_end_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = - SessionManagerLayer::new(session_store).with_expiry(Expiry::OnSessionEnd); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age"))); - - Ok(()) - } - - #[tokio::test] - async fn expiry_on_inactivity_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let inactivity_duration = time::Duration::hours(2); - let session_layer = SessionManagerLayer::new(session_store) - .with_expiry(Expiry::OnInactivity(inactivity_duration)); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - let expected_max_age = inactivity_duration.whole_seconds(); - assert!(cookie_has_expected_max_age(&res, expected_max_age)); - - Ok(()) - } - - #[tokio::test] - async fn expiry_at_date_time_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1); - let session_layer = - SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time)); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds(); - assert!(cookie_has_expected_max_age(&res, expected_max_age)); - - Ok(()) - } - - #[tokio::test] - async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store.clone()) - .with_expiry(Expiry::OnSessionEnd) - .with_always_save(true); - let mut svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req1 = Request::builder().body(Body::empty())?; - let res1 = svc.call(req1).await?; - let sid1 = get_session_id(&res1); - let rec1 = get_record(&session_store, &sid1).await; - let req2 = Request::builder() - .header(http::header::COOKIE, &format!("id={}", sid1)) - .body(Body::empty())?; - let res2 = svc.call(req2).await?; - let sid2 = get_session_id(&res2); - let rec2 = get_record(&session_store, &sid2).await; - - assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age"))); - assert!(sid1 == sid2); - assert!(rec1.expiry_date < rec2.expiry_date); - - Ok(()) - } - - #[tokio::test] - async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let inactivity_duration = time::Duration::hours(2); - let session_layer = SessionManagerLayer::new(session_store.clone()) - .with_expiry(Expiry::OnInactivity(inactivity_duration)) - .with_always_save(true); - let mut svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req1 = Request::builder().body(Body::empty())?; - let res1 = svc.call(req1).await?; - let sid1 = get_session_id(&res1); - let rec1 = get_record(&session_store, &sid1).await; - let req2 = Request::builder() - .header(http::header::COOKIE, &format!("id={}", sid1)) - .body(Body::empty())?; - let res2 = svc.call(req2).await?; - let sid2 = get_session_id(&res2); - let rec2 = get_record(&session_store, &sid2).await; - - let expected_max_age = inactivity_duration.whole_seconds(); - assert!(cookie_has_expected_max_age(&res2, expected_max_age)); - assert!(sid1 == sid2); - assert!(rec1.expiry_date < rec2.expiry_date); - - Ok(()) - } - - #[tokio::test] - async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1); - let session_layer = SessionManagerLayer::new(session_store.clone()) - .with_expiry(Expiry::AtDateTime(expiry_time)) - .with_always_save(true); - let mut svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req1 = Request::builder().body(Body::empty())?; - let res1 = svc.call(req1).await?; - let sid1 = get_session_id(&res1); - let rec1 = get_record(&session_store, &sid1).await; - let req2 = Request::builder() - .header(http::header::COOKIE, &format!("id={}", sid1)) - .body(Body::empty())?; - let res2 = svc.call(req2).await?; - let sid2 = get_session_id(&res2); - let rec2 = get_record(&session_store, &sid2).await; - - let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds(); - assert!(cookie_has_expected_max_age(&res2, expected_max_age)); - assert!(sid1 == sid2); - assert!(rec1.expiry_date == rec2.expiry_date); - - Ok(()) - } - - #[tokio::test] - async fn secure_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_secure(true); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("Secure"))); - - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_secure(false); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| !s.contains("Secure"))); - - Ok(()) - } - - #[tokio::test] - async fn path_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar"); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar"))); - - Ok(()) - } - - #[tokio::test] - async fn domain_test() -> anyhow::Result<()> { - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com"); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com"))); - - Ok(()) - } - - #[cfg(feature = "signed")] - #[tokio::test] - async fn signed_test() -> anyhow::Result<()> { - let key = Key::generate(); - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_signed(key); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_some()); - - Ok(()) - } - - #[cfg(feature = "private")] - #[tokio::test] - async fn private_test() -> anyhow::Result<()> { - let key = Key::generate(); - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store).with_private(key); - let svc = ServiceBuilder::new() - .layer(session_layer) - .service_fn(handler); - - let req = Request::builder().body(Body::empty())?; - let res = svc.oneshot(req).await?; - - assert!(res.headers().get(http::header::SET_COOKIE).is_some()); - - Ok(()) - } - - fn cookie_value_matches(res: &Response, matcher: F) -> bool - where - F: FnOnce(&str) -> bool, - { - res.headers() - .get(http::header::SET_COOKIE) - .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher)) - } - - fn cookie_has_expected_max_age(res: &Response, expected_value: i64) -> bool { - res.headers() - .get(http::header::SET_COOKIE) - .is_some_and(|set_cookie| { - set_cookie.to_str().is_ok_and(|s| { - let max_age_value = s - .split("Max-Age=") - .nth(1) - .unwrap_or_default() - .split(';') - .next() - .unwrap_or_default() - .parse::() - .unwrap_or_default(); - (max_age_value - expected_value).abs() <= 1 - }) - }) - } - - fn get_session_id(res: &Response) -> String { - res.headers() - .get(http::header::SET_COOKIE) - .unwrap() - .to_str() - .unwrap() - .split("id=") - .nth(1) - .unwrap() - .split(";") - .next() - .unwrap() - .to_string() - } - - async fn get_record(store: &impl SessionStore, id: &str) -> Record { - store - .load(&Id::from_str(id).unwrap()) - .await - .unwrap() - .unwrap() - } -} diff --git a/src/session.rs b/src/session.rs new file mode 100644 index 0000000..fd2b9d2 --- /dev/null +++ b/src/session.rs @@ -0,0 +1,456 @@ +//! A session which allows HTTP applications to associate data with visitors. +//! +//! The structs provided here have a strict API, but they are designed to be nearly impossible to +//! misuse. Luckily, they only have a handful of methods, and all of them document how they work. +use std::{ + fmt::Debug, + mem::ManuallyDrop, + sync::{Arc, Mutex}, +}; +// TODO: Remove send + sync bounds on `R` once return type notation is stable. + +use tower_sesh_core::{expires::Expires, id::Id, Expiry, SessionStore}; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum SessionUpdate { + Delete, + Set(Id, Expiry), +} + +pub(crate) type Updater = Arc>>; + +/// A session that is lazily loaded. +/// +/// This is struct provided throught a Request's Extensions by the [`SessionManager`] middleware. +/// If you happen to use `axum`, you can use this struct as an extractor since it implements +/// [`FromRequestParts`]. +/// +/// When this struct refers to the "underlying store error", it is referring to the fact that the +/// store used returned a "hard" error. For example, it could be a connection error, a protocol error, +/// a timeout, etc. A counterexample would be the [`SessionState`] not being found in the store, which is +/// not considered an error by the [`SessionStore`] trait. +/// +/// # Examples +/// - If you are using `axum`, and you have enabled the `extractor` feature, you can use this +/// struct as an extractor: +/// ```rust +/// use tower_sesh::{Session, MemoryStore}; +/// +/// async fn handler(session: Session>) -> String { +/// unimplemented!() +/// } +/// ``` +/// The extractor will error if the handler was called without a `SessionManager` middleware. +/// +/// - Otherwise, you can extract it from a request's extensions: +/// ``` +/// use tower_sesh::{Session, MemoryStore}; +/// use axum_core::{extract::Request, body::Body}; +/// +/// async fn handler(mut req: Request) -> String { +/// let Some(session) = req.extensions_mut().remove::>>() else { +/// return "No session found".to_string(); +/// }; +/// unimplemented!() +/// // ... +/// } +/// ``` +/// Again, the session will not be found if the handler was called without a `SessionManager` +/// middleware. +#[derive(Debug, Clone)] +pub struct Session { + /// This will be `None` if the handler has not received a session cookie or if the it could + /// not be parsed. + pub(crate) id: Option, + pub(crate) store: Store, + pub(crate) updater: Updater, +} + +impl Session { + /// Try to load the session from the store. + /// + /// The return type of this method looks convoluted, so let's break it down: + /// - The outer `Result` will return `Err(...)` if the underlying session store errors. + /// - Otherwise, it will return `Ok(...)`, where `...` is an `Option`. + /// - The inner `Option` will be `None` if the session was not found in the store. + /// - Otherwise, it will be `Some(...)`, where `...` is the loaded session. + /// + /// # Error + /// + /// Errors if the underlying store errors. + /// + /// # Example + /// ```rust + /// use tower_sesh::{Session, MemoryStore, Expires}; + /// + /// #[derive(Clone)] + /// struct User { + /// id: u64, + /// admin: bool, + /// } + /// + /// impl Expires for User {} + /// + /// async fn handler(session: Session>) -> String { + /// match session.load().await { + /// Ok(Some(session)) => { + /// "User has a valid session" + /// } + /// Ok(None) => { + /// "User does not have a session, redirect to login?" + /// } + /// Err(_error) => { + /// "An error occurred while loading the session" + /// } + /// }.to_string() + /// } + /// ``` + pub async fn load(mut self) -> Result>, Store::Error> + where + R: Send + Sync, + Store: SessionStore, + { + Ok(if let Some(id) = self.id { + if let Some(record) = self.store.load(&id).await? { + Some(SessionState { + store: self.store, + id, + data: record, + updater: self.updater, + }) + } else { + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Delete); + None + } + } else { + None + }) + } + + /// Create a new session with the given data, using the expiry from the data's `Expires` impl. + /// + /// # Error + /// + /// Errors if the underlying store errors. + /// + /// # Example + /// ```rust + /// use tower_sesh::{Session, MemoryStore, Expires}; + /// + /// #[derive(Clone)] + /// struct User { + /// id: u64, + /// admin: bool, + /// } + /// + /// impl Expires for User {} + /// + /// async fn handler(session: Session>) -> String { + /// let user = User { id: 1, admin: false }; + /// match session.create(user).await { + /// Ok(session) => { + /// "We have successfully created a new session with the user's id" + /// } + /// Err(_error) => { + /// "An error occurred while loading the session" + /// } + /// }.to_string() + /// } + /// ``` + pub async fn create(self, data: R) -> Result, Store::Error> + where + R: Expires + Send + Sync, + Store: SessionStore, + { + let exp = data.expires(); + self.create_with_expiry(data, exp).await + } + + /// Create a new session with the given data and expiry. See [`Session::create`] for an example. + /// + /// # Error + /// + /// Errors if the underlying store errors. + pub async fn create_with_expiry( + mut self, + data: R, + exp: Expiry, + ) -> Result, Store::Error> + where + R: Send + Sync, + Store: SessionStore, + { + let id = self.store.create(&data).await?; + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Set(id, exp)); + Ok(SessionState { + store: self.store, + id, + data, + updater: self.updater, + }) + } +} + +#[cfg(feature = "extractor")] +pub use self::extractor::*; + +#[cfg(feature = "extractor")] +mod extractor { + use super::*; + use axum_core::{ + body::Body, + extract::FromRequestParts, + response::{IntoResponse, Response}, + }; + use http::request::Parts; + + /// A rejection that is returned from the [`Session`] extractor when the [`SessionManagerLayer`] + /// middleware is not set. + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + #[cfg_attr(docsrs, doc(cfg(feature = "extractor")))] + pub struct NoMiddleware; + + impl std::fmt::Display for NoMiddleware { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Missing session middleware. Is it added to the app?") + } + } + + impl std::error::Error for NoMiddleware {} + + impl IntoResponse for NoMiddleware { + fn into_response(self) -> Response { + let mut resp = Response::new(Body::from(self.to_string())); + *resp.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; + resp + } + } + + #[async_trait::async_trait] + #[cfg_attr(docsrs, doc(cfg(feature = "extractor")))] + impl FromRequestParts for Session + where + Store: Send + Sync + 'static, + { + type Rejection = NoMiddleware; + + async fn from_request_parts( + parts: &mut Parts, + _state: &State, + ) -> Result { + let session = parts + .extensions + .remove::>() + .ok_or(NoMiddleware)?; + + Ok(session) + } + } +} + +/// A loaded session. +/// +/// When this struct refers to the "underlying store error", it is referring to the fact that the +/// store used returned a "hard" error. For example, it could be a connection error, a protocol error, +/// a timeout, etc. A counterexample would be the session not being found in the store, which is +/// not considered an error by the `SessionStore` trait. +#[derive(Debug, Clone)] +pub struct SessionState { + store: Store, + id: Id, + data: R, + updater: Updater, +} + +impl SessionState { + /// Read the data associated with the session. + pub fn data(&self) -> &R { + &self.data + } +} + +impl SessionState +where + R: Send + Sync, + Store: SessionStore, +{ + /// Update the session data, returning the session if successful. + /// + /// It updates the sessions' expiry through the [`Expires`] impl. If your data does not implement + /// [`Expires`], or you want to set a different expiry, use [`DataMut::save_with_expiry`]. + /// + /// This method returns the `Session` if the data was saved successfully. It returns + /// `Ok(None)` when the session was deleted or expired between the time it was loaded and the + /// time this method is called. + /// + /// # Error + /// + /// Errors if the underlying store errors. + /// + /// # Example + /// ``` + /// use tower_sesh::{SessionState, Expires, MemoryStore}; + /// + /// #[derive(Clone)] + /// struct User { + /// id: u64, + /// admin: bool, + /// } + /// + /// impl Expires for User {} + /// + /// async fn upgrade_priviledges(state: SessionState>) -> Option { + /// let new_state = state.update(|user| { + /// user.admin = true; + /// }).await.ok()??; + /// assert!(new_state.data().admin); + /// Some("User has been upgraded to admin".to_string()) + /// } + /// ``` + pub async fn update(self, update: F) -> Result>, Store::Error> + where + F: FnOnce(&mut R), + R: Expires, + { + let exp = self.data.expires(); + self.update_with_expiry(update, exp).await + } + + /// Update the session data with a provided expiry, returning the session if successful. + /// + /// Similar to [`SessionState::update`], but allows you to set an expiry for types that don't + /// implement [`Expires`]. See [that method's documentation][SessionState::update] for more + /// information. + pub async fn update_with_expiry( + mut self, + update: F, + exp: Expiry, + ) -> Result>, Store::Error> + where + F: FnOnce(&mut R), + { + update(&mut self.data); + Ok(if self.store.save(&self.id, &self.data).await? { + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Set(self.id, exp)); + Some(self) + } else { + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Delete); + None + }) + } + + /// Delete the session from the store. + /// + /// This method returns a boolean indicating whether the session was deleted from the store. + /// If the `Store` returns `Ok(false)` if the session simply did not exist. This can happen if + /// it was deleted by another request or if the session expired between the time it was + /// loaded and the time this method was called. + /// + /// # Error + /// + /// Errors if the underlying store errors. + /// + /// # Example + /// ``` + /// use tower_sesh::{SessionState, MemoryStore, Expires}; + /// + /// #[derive(Clone)] + /// struct User; + /// + /// impl Expires for User {} + /// + /// async fn logout(state: SessionState>) -> Option { + /// Some(if state.delete().await.ok()? { + /// "User has been logged out".to_string() + /// } else { + /// "User was not logged in".to_string() + /// }) + /// } + pub async fn delete(mut self) -> Result { + let deleted = self.store.delete(&self.id).await?; + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Delete); + let _ = ManuallyDrop::new(self); + Ok(deleted) + } + + /// Cycle the session ID. + /// + /// This consumes the current session and returns a new session with the new ID. This method + /// should be used to mitigate [session fixation attacks](https://www.acrossecurity.com/papers/session_fixation.pdf). + /// + /// This method returns `Ok(None)` if the session was deleted or expired between the time it + /// was loaded and the time this method was called. Otherwise, it returns the new + /// `Some(Session)`. + /// + /// # Error + /// + /// Errors if the underlying store errors. + /// + /// # Example + /// ``` + /// use tower_sesh::{SessionState, MemoryStore, Expires}; + /// + /// #[derive(Clone)] + /// struct User; + /// + /// impl Expires for User {} + /// + /// async fn cycle(state: SessionState>) -> Option { + /// Some(if let Some(new_state) = state.cycle().await.ok()? { + /// "Session has been cycled".to_string() + /// } else { + /// "Session was not found".to_string() + /// }) + /// } + /// ``` + pub async fn cycle(self) -> Result>, Store::Error> + where + R: Expires, + { + let exp = self.data.expires(); + self.cycle_with_expiry(exp).await + } + + /// Cycle the session ID with a provided expiry, instead of the one from the [`Expires`] trait. + /// + /// Similar to [`SessionState::cycle`], but allows you to set an expiry for types that don't + /// implement [`Expires`]. See [that method's documentation][SessionState::cycle] for more information. + pub async fn cycle_with_expiry( + mut self, + exp: Expiry, + ) -> Result>, Store::Error> { + if let Some(new_id) = self.store.cycle_id(&self.id).await? { + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Set(new_id, exp)); + self.id = new_id; + return Ok(Some(self)); + } + self.updater + .lock() + .expect("lock should not be poisoned") + .replace(SessionUpdate::Delete); + Ok(None) + } + + /// Get the session store. + pub fn into_store(self) -> Store { + self.store + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 134b9d6..09a79e0 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -3,462 +3,453 @@ use axum_core::body::Body; use http::{header, HeaderMap}; use http_body_util::BodyExt; use time::{Duration, OffsetDateTime}; -use tower_cookies::{cookie, Cookie}; -use tower_sessions::{Expiry, Session, SessionManagerLayer, SessionStore}; +use tower_sesh::{Expires, Expiry, Session, SessionManagerLayer, SessionStore}; -fn routes() -> Router { + +#[derive(Debug, Clone, Copy)] +pub struct Foo(pub u32); + +impl Expires for Foo { + fn expires(&self) -> Expiry { + Expiry::OnInactivity(Duration::hours(1)) + } +} + +fn routes() -> Router where +Store: SessionStore + Clone + 'static, +Store::Error: std::fmt::Debug, +{ + Router::new() - .route("/", get(|_: Session| async move { "Hello, world!" })) + .route("/", get(|_: Session| async move { "Hello, world!" })) .route( - "/insert", - get(|session: Session| async move { - session.insert("foo", 42).await.unwrap(); + "/create", + get(|session: Session| async move { + session.create(Foo(42)).await.unwrap(); }), ) .route( "/get", - get(|session: Session| async move { - format!("{}", session.get::("foo").await.unwrap().unwrap()) - }), - ) - .route( - "/get_value", - get(|session: Session| async move { - format!("{:?}", session.get_value("foo").await.unwrap()) + get(|session: Session| async move { + format!("{:?}", session.load().await.unwrap().unwrap().data()) }), ) .route( "/remove", - get(|session: Session| async move { - session.remove::("foo").await.unwrap(); - }), - ) - .route( - "/remove_value", - get(|session: Session| async move { - session.remove_value("foo").await.unwrap(); + get(|session: Session| async move { + let state = session.load().await.unwrap().unwrap(); + println!("{}", state.delete().await.unwrap()); }), ) .route( "/cycle_id", - get(|session: Session| async move { - session.cycle_id().await.unwrap(); - }), - ) - .route( - "/flush", - get(|session: Session| async move { - session.flush().await.unwrap(); + get(|session: Session| async move { + let state = session.load().await.unwrap().unwrap(); + state.cycle().await.unwrap(); }), ) .route( "/set_expiry", - get(|session: Session| async move { + get(|session: Session| async move { let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc() + Duration::days(1)); - session.set_expiry(Some(expiry)); - }), - ) - .route( - "/remove_expiry", - get(|session: Session| async move { - session.set_expiry(Some(Expiry::OnSessionEnd)); + session.load().await.unwrap().unwrap().update_with_expiry(|_| {}, expiry).await.unwrap(); }), ) } -pub fn build_app( - mut session_manager: SessionManagerLayer, - max_age: Option, - domain: Option, -) -> Router { - if let Some(max_age) = max_age { - session_manager = session_manager.with_expiry(Expiry::OnInactivity(max_age)); - } - - if let Some(domain) = domain { - session_manager = session_manager.with_domain(domain); - } - - routes().layer(session_manager) -} - -pub async fn body_string(body: Body) -> String { - let bytes = body.collect().await.unwrap().to_bytes(); - String::from_utf8_lossy(&bytes).into() -} - -pub fn get_session_cookie(headers: &HeaderMap) -> Result, cookie::ParseError> { - headers - .get_all(header::SET_COOKIE) - .iter() - .flat_map(|header| header.to_str()) - .next() - .ok_or(cookie::ParseError::MissingPair) - .and_then(Cookie::parse_encoded) -} - -#[macro_export] -macro_rules! route_tests { - ($create_app:expr) => { - use axum::body::Body; - use http::{header, Request, StatusCode}; - use time::Duration; - use tower::ServiceExt; - use tower_cookies::{cookie::SameSite, Cookie}; - use $crate::common::{body_string, get_session_cookie}; - - #[tokio::test] - async fn no_session_set() { - let req = Request::builder().uri("/").body(Body::empty()).unwrap(); - let res = $create_app(Some(Duration::hours(1)), None) - .await - .oneshot(req) - .await - .unwrap(); - - assert!(res - .headers() - .get_all(header::SET_COOKIE) - .iter() - .next() - .is_none()); - } - - #[tokio::test] - async fn bogus_session_cookie() { - let session_cookie = Cookie::new("id", "AAAAAAAAAAAAAAAAAAAAAA"); - let req = Request::builder() - .uri("/insert") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = $create_app(Some(Duration::hours(1)), None) - .await - .oneshot(req) - .await - .unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - assert_eq!(res.status(), StatusCode::OK); - assert_ne!(session_cookie.value(), "AAAAAAAAAAAAAAAAAAAAAA"); - } - - #[tokio::test] - async fn malformed_session_cookie() { - let session_cookie = Cookie::new("id", "malformed"); - let req = Request::builder() - .uri("/") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = $create_app(Some(Duration::hours(1)), None) - .await - .oneshot(req) - .await - .unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - assert_ne!(session_cookie.value(), "malformed"); - assert_eq!(res.status(), StatusCode::OK); - } - - #[tokio::test] - async fn insert_session() { - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = $create_app(Some(Duration::hours(1)), None) - .await - .oneshot(req) - .await - .unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - assert_eq!(session_cookie.name(), "id"); - assert_eq!(session_cookie.http_only(), Some(true)); - assert_eq!(session_cookie.same_site(), Some(SameSite::Strict)); - assert!(session_cookie - .max_age() - .is_some_and(|dt| dt <= Duration::hours(1))); - assert_eq!(session_cookie.secure(), Some(true)); - assert_eq!(session_cookie.path(), Some("/")); - } - - #[tokio::test] - async fn session_max_age() { - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = $create_app(None, None).await.oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - assert_eq!(session_cookie.name(), "id"); - assert_eq!(session_cookie.http_only(), Some(true)); - assert_eq!(session_cookie.same_site(), Some(SameSite::Strict)); - assert!(session_cookie.max_age().is_none()); - assert_eq!(session_cookie.secure(), Some(true)); - assert_eq!(session_cookie.path(), Some("/")); - } - - #[tokio::test] - async fn get_session() { - let app = $create_app(Some(Duration::hours(1)), None).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/get") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); - - assert_eq!(body_string(res.into_body()).await, "42"); - } - - #[tokio::test] - async fn get_no_value() { - let app = $create_app(Some(Duration::hours(1)), None).await; - - let req = Request::builder() - .uri("/get_value") - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - assert_eq!(body_string(res.into_body()).await, "None"); - } - - #[tokio::test] - async fn remove_last_value() { - let app = $create_app(Some(Duration::hours(1)), None).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/remove_value") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - app.clone().oneshot(req).await.unwrap(); - - let req = Request::builder() - .uri("/get_value") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - assert_eq!(body_string(res.into_body()).await, "None"); - } - - #[tokio::test] - async fn cycle_session_id() { - let app = $create_app(Some(Duration::hours(1)), None).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let first_session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/cycle_id") - .header(header::COOKIE, first_session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let second_session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/get") - .header(header::COOKIE, second_session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = dbg!(app.oneshot(req).await).unwrap(); - - assert_ne!(first_session_cookie.value(), second_session_cookie.value()); - assert_eq!(body_string(res.into_body()).await, "42"); - } - - #[tokio::test] - async fn flush_session() { - let app = $create_app(Some(Duration::hours(1)), None).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/flush") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - assert_eq!(session_cookie.value(), ""); - assert_eq!(session_cookie.max_age(), Some(Duration::ZERO)); - assert_eq!(session_cookie.path(), Some("/")); - } - - #[tokio::test] - async fn flush_with_domain() { - let app = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let req = Request::builder() - .uri("/flush") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - assert_eq!(session_cookie.value(), ""); - assert_eq!(session_cookie.max_age(), Some(Duration::ZERO)); - assert_eq!(session_cookie.domain(), Some("localhost")); - assert_eq!(session_cookie.path(), Some("/")); - } - - #[tokio::test] - async fn set_expiry() { - let app = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = Duration::hours(1); - let actual_duration = session_cookie.max_age().unwrap(); - let tolerance = Duration::seconds(1); - - assert!( - actual_duration >= expected_duration - tolerance - && actual_duration <= expected_duration + tolerance, - "Duration is not within the acceptable range: {:?}", - actual_duration - ); - - let req = Request::builder() - .uri("/set_expiry") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = Duration::days(1); - let actual_duration = session_cookie.max_age().unwrap(); - let tolerance = Duration::seconds(1); - - assert!( - actual_duration >= expected_duration - tolerance - && actual_duration <= expected_duration + tolerance, - "Duration is not within the acceptable range: {:?}", - actual_duration - ); - } - - #[tokio::test] - async fn change_expiry_type() { - let app = $create_app(None, Some("localhost".to_string())).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = None; - let actual_duration = session_cookie.max_age(); - - assert_eq!(actual_duration, expected_duration, "Duration is not None"); - - let req = Request::builder() - .uri("/set_expiry") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app.oneshot(req).await.unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = Duration::days(1); - assert!(session_cookie.max_age().is_some(), "Duration is None"); - let actual_duration = session_cookie.max_age().unwrap(); - let tolerance = Duration::seconds(1); - - assert!( - actual_duration >= expected_duration - tolerance - && actual_duration <= expected_duration + tolerance, - "Duration is not within the acceptable range: {:?}", - actual_duration - ); - - let app2 = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; - - let req = Request::builder() - .uri("/insert") - .body(Body::empty()) - .unwrap(); - let res = app2.clone().oneshot(req).await.unwrap(); - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = Duration::hours(1); - let actual_duration = session_cookie.max_age().unwrap(); - let tolerance = Duration::seconds(1); - - assert!( - actual_duration >= expected_duration - tolerance - && actual_duration <= expected_duration + tolerance, - "Duration is not within the acceptable range: {:?}", - actual_duration - ); - - let req = Request::builder() - .uri("/remove_expiry") - .header(header::COOKIE, session_cookie.encoded().to_string()) - .body(Body::empty()) - .unwrap(); - let res = app2.oneshot(req).await.unwrap(); - - let session_cookie = get_session_cookie(res.headers()).unwrap(); - - let expected_duration = None; - let actual_duration = session_cookie.max_age(); - - assert_eq!(actual_duration, expected_duration, "Duration is not None"); - } - }; -} +// pub fn build_app( +// mut session_manager: SessionManagerLayer, +// max_age: Option, +// domain: Option, +// ) -> Router { +// if let Some(max_age) = max_age { +// session_manager = session_manager.with_expiry(Expiry::OnInactivity(max_age)); +// } +// +// if let Some(domain) = domain { +// session_manager = session_manager.with_domain(domain); +// } +// +// routes().layer(session_manager) +// } +// +// pub async fn body_string(body: Body) -> String { +// let bytes = body.collect().await.unwrap().to_bytes(); +// String::from_utf8_lossy(&bytes).into() +// } +// +// pub fn get_session_cookie(headers: &HeaderMap) -> Result, cookie::ParseError> { +// headers +// .get_all(header::SET_COOKIE) +// .iter() +// .flat_map(|header| header.to_str()) +// .next() +// .ok_or(cookie::ParseError::MissingPair) +// .and_then(Cookie::parse_encoded) +// } +// +// #[macro_export] +// macro_rules! route_tests { +// ($create_app:expr) => { +// use axum::body::Body; +// use http::{header, Request, StatusCode}; +// use time::Duration; +// use tower::ServiceExt; +// use tower_cookies::{cookie::SameSite, Cookie}; +// use $crate::common::{body_string, get_session_cookie}; +// +// #[tokio::test] +// async fn no_session_set() { +// let req = Request::builder().uri("/").body(Body::empty()).unwrap(); +// let res = $create_app(Some(Duration::hours(1)), None) +// .await +// .oneshot(req) +// .await +// .unwrap(); +// +// assert!(res +// .headers() +// .get_all(header::SET_COOKIE) +// .iter() +// .next() +// .is_none()); +// } +// +// #[tokio::test] +// async fn bogus_session_cookie() { +// let session_cookie = Cookie::new("id", "AAAAAAAAAAAAAAAAAAAAAA"); +// let req = Request::builder() +// .uri("/insert") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = $create_app(Some(Duration::hours(1)), None) +// .await +// .oneshot(req) +// .await +// .unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// assert_eq!(res.status(), StatusCode::OK); +// assert_ne!(session_cookie.value(), "AAAAAAAAAAAAAAAAAAAAAA"); +// } +// +// #[tokio::test] +// async fn malformed_session_cookie() { +// let session_cookie = Cookie::new("id", "malformed"); +// let req = Request::builder() +// .uri("/") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = $create_app(Some(Duration::hours(1)), None) +// .await +// .oneshot(req) +// .await +// .unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// assert_ne!(session_cookie.value(), "malformed"); +// assert_eq!(res.status(), StatusCode::OK); +// } +// +// #[tokio::test] +// async fn insert_session() { +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = $create_app(Some(Duration::hours(1)), None) +// .await +// .oneshot(req) +// .await +// .unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// assert_eq!(session_cookie.name(), "id"); +// assert_eq!(session_cookie.http_only(), Some(true)); +// assert_eq!(session_cookie.same_site(), Some(SameSite::Strict)); +// assert!(session_cookie +// .max_age() +// .is_some_and(|dt| dt <= Duration::hours(1))); +// assert_eq!(session_cookie.secure(), Some(true)); +// assert_eq!(session_cookie.path(), Some("/")); +// } +// +// #[tokio::test] +// async fn session_max_age() { +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = $create_app(None, None).await.oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// assert_eq!(session_cookie.name(), "id"); +// assert_eq!(session_cookie.http_only(), Some(true)); +// assert_eq!(session_cookie.same_site(), Some(SameSite::Strict)); +// assert!(session_cookie.max_age().is_none()); +// assert_eq!(session_cookie.secure(), Some(true)); +// assert_eq!(session_cookie.path(), Some("/")); +// } +// +// #[tokio::test] +// async fn get_session() { +// let app = $create_app(Some(Duration::hours(1)), None).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/get") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// assert_eq!(res.status(), StatusCode::OK); +// +// assert_eq!(body_string(res.into_body()).await, "42"); +// } +// +// #[tokio::test] +// async fn get_no_value() { +// let app = $create_app(Some(Duration::hours(1)), None).await; +// +// let req = Request::builder() +// .uri("/get_value") +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// assert_eq!(body_string(res.into_body()).await, "None"); +// } +// +// #[tokio::test] +// async fn remove_last_value() { +// let app = $create_app(Some(Duration::hours(1)), None).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/remove_value") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// app.clone().oneshot(req).await.unwrap(); +// +// let req = Request::builder() +// .uri("/get_value") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// assert_eq!(body_string(res.into_body()).await, "None"); +// } +// +// #[tokio::test] +// async fn cycle_session_id() { +// let app = $create_app(Some(Duration::hours(1)), None).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let first_session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/cycle_id") +// .header(header::COOKIE, first_session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let second_session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/get") +// .header(header::COOKIE, second_session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = dbg!(app.oneshot(req).await).unwrap(); +// +// assert_ne!(first_session_cookie.value(), second_session_cookie.value()); +// assert_eq!(body_string(res.into_body()).await, "42"); +// } +// +// #[tokio::test] +// async fn flush_session() { +// let app = $create_app(Some(Duration::hours(1)), None).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/flush") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// assert_eq!(session_cookie.value(), ""); +// assert_eq!(session_cookie.max_age(), Some(Duration::ZERO)); +// assert_eq!(session_cookie.path(), Some("/")); +// } +// +// #[tokio::test] +// async fn flush_with_domain() { +// let app = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let req = Request::builder() +// .uri("/flush") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// assert_eq!(session_cookie.value(), ""); +// assert_eq!(session_cookie.max_age(), Some(Duration::ZERO)); +// assert_eq!(session_cookie.domain(), Some("localhost")); +// assert_eq!(session_cookie.path(), Some("/")); +// } +// +// #[tokio::test] +// async fn set_expiry() { +// let app = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = Duration::hours(1); +// let actual_duration = session_cookie.max_age().unwrap(); +// let tolerance = Duration::seconds(1); +// +// assert!( +// actual_duration >= expected_duration - tolerance +// && actual_duration <= expected_duration + tolerance, +// "Duration is not within the acceptable range: {:?}", +// actual_duration +// ); +// +// let req = Request::builder() +// .uri("/set_expiry") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = Duration::days(1); +// let actual_duration = session_cookie.max_age().unwrap(); +// let tolerance = Duration::seconds(1); +// +// assert!( +// actual_duration >= expected_duration - tolerance +// && actual_duration <= expected_duration + tolerance, +// "Duration is not within the acceptable range: {:?}", +// actual_duration +// ); +// } +// +// #[tokio::test] +// async fn change_expiry_type() { +// let app = $create_app(None, Some("localhost".to_string())).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = None; +// let actual_duration = session_cookie.max_age(); +// +// assert_eq!(actual_duration, expected_duration, "Duration is not None"); +// +// let req = Request::builder() +// .uri("/set_expiry") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app.oneshot(req).await.unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = Duration::days(1); +// assert!(session_cookie.max_age().is_some(), "Duration is None"); +// let actual_duration = session_cookie.max_age().unwrap(); +// let tolerance = Duration::seconds(1); +// +// assert!( +// actual_duration >= expected_duration - tolerance +// && actual_duration <= expected_duration + tolerance, +// "Duration is not within the acceptable range: {:?}", +// actual_duration +// ); +// +// let app2 = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; +// +// let req = Request::builder() +// .uri("/insert") +// .body(Body::empty()) +// .unwrap(); +// let res = app2.clone().oneshot(req).await.unwrap(); +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = Duration::hours(1); +// let actual_duration = session_cookie.max_age().unwrap(); +// let tolerance = Duration::seconds(1); +// +// assert!( +// actual_duration >= expected_duration - tolerance +// && actual_duration <= expected_duration + tolerance, +// "Duration is not within the acceptable range: {:?}", +// actual_duration +// ); +// +// let req = Request::builder() +// .uri("/remove_expiry") +// .header(header::COOKIE, session_cookie.encoded().to_string()) +// .body(Body::empty()) +// .unwrap(); +// let res = app2.oneshot(req).await.unwrap(); +// +// let session_cookie = get_session_cookie(res.headers()).unwrap(); +// +// let expected_duration = None; +// let actual_duration = session_cookie.max_age(); +// +// assert_eq!(actual_duration, expected_duration, "Duration is not None"); +// } +// }; +// } diff --git a/tests/integration-tests.rs b/tests/integration-tests.rs index 5a30d22..aad31d8 100644 --- a/tests/integration-tests.rs +++ b/tests/integration-tests.rs @@ -1,18 +1,18 @@ -#[macro_use] -mod common; +// #[macro_use] +// mod common; -#[cfg(all(test, feature = "axum-core", feature = "memory-store"))] -mod memory_store_tests { - use axum::Router; - use tower_sessions::{MemoryStore, SessionManagerLayer}; - - use crate::common::build_app; - - async fn app(max_age: Option, domain: Option) -> Router { - let session_store = MemoryStore::default(); - let session_manager = SessionManagerLayer::new(session_store).with_secure(true); - build_app(session_manager, max_age, domain) - } - - route_tests!(app); -} +// #[cfg(all(test, feature = "extractor", feature = "memory-store"))] +// mod memory_store_tests { +// use axum::Router; +// use tower_sesh_core::{MemoryStore, SessionManagerLayer}; +// +// use crate::common::build_app; +// +// async fn app(max_age: Option, domain: Option) -> Router { +// let session_store = MemoryStore::default(); +// let session_manager = SessionManagerLayer::new(session_store).with_secure(true); +// build_app(session_manager, max_age, domain) +// } +// +// route_tests!(app); +// } diff --git a/tower-sesh-core/Cargo.toml b/tower-sesh-core/Cargo.toml new file mode 100644 index 0000000..7d692b8 --- /dev/null +++ b/tower-sesh-core/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "tower-sesh-core" +description = "Core types and traits for tower-sesh." +documentation.workspace = true +version.workspace = true +license.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true + +[features] + +[dependencies] +time = { version = "0.3.36", features = ["serde"] } +base64 = "0.22.0" +futures-util = { version = "0.3.30", default-features = false } +serde = { version = "1.0.210", features = ["derive"] } +either = "1.13" + +[dev-dependencies] +tower-sesh = { workspace = true, features = [] } +tokio-test = "0.4.3" +tokio = { workspace = true, features = ["rt", "macros"] } +mockall = "0.13.0" diff --git a/tower-sesh-core/src/expires.rs b/tower-sesh-core/src/expires.rs new file mode 100644 index 0000000..8212299 --- /dev/null +++ b/tower-sesh-core/src/expires.rs @@ -0,0 +1,108 @@ +use serde::{Deserialize, Serialize}; + +/// Trait for types that can expire. +/// +/// If a [`SessionStore`][crate::SessionStore] does session expiration management, +/// it should rely on this trait to access a record's expiration. +/// +/// If a [`SessionStore`][crate::SessionStore] implementation relies on this trait, then it should +/// also check the expiration of a record every time it is saved, and it should update the +/// record's expiration on the backend accordingly. +/// +/// # Examples +/// - A record that should not expire: +/// ``` +/// use tower_sesh_core::{Expires, Expiry}; +/// +/// struct NeverExpires; +/// +/// impl Expires for NeverExpires {} +/// ``` +/// +/// - A record that should expire after 5 minutes of inactivity: +/// ``` +/// use time::{Duration, OffsetDateTime}; +/// use tower_sesh_core::{Expires, Expiry}; +/// +/// struct ExpiresAfter5Minutes; +/// +/// impl Expires for ExpiresAfter5Minutes { +/// fn expires(&self) -> Expiry { +/// Expiry::OnInactivity(Duration::minutes(5)) +/// // OR +/// // Expiry::OnInactivity(OffsetDateTime::now_utc() + Duration::minutes(5)); +/// } +/// } +/// ``` +/// +/// - A record that keeps track of its own expiration: +/// ``` +/// use time::{Duration, OffsetDateTime}; +/// use tower_sesh_core::{Expires, Expiry}; +/// +/// struct CustomExpiry { +/// expiry: Expiry, +/// } +/// +/// impl Expires for CustomExpiry { +/// fn expires(&self) -> Expiry { +/// self.expiry +/// } +/// } +/// ``` +pub trait Expires { + /// Returns the expiration of the record. + /// + /// By default, this always returns [`Expiry::OnSessionEnd`]. If the record should expire, then + /// one needs to implement this method. + fn expires(&self) -> Expiry { + Expiry::OnSessionEnd + } +} + +/// Session expiry configuration. +/// +/// # Examples +/// +/// ```rust +/// use time::{Duration, OffsetDateTime}; +/// use tower_sesh_core::Expiry; +/// +/// // Will be expired on "session end". +/// let expiry = Expiry::OnSessionEnd; +/// +/// // Will be expired in five minutes from last acitve. +/// let expiry = Expiry::OnInactivity(Duration::minutes(5)); +/// +/// // Will be expired at the given timestamp. +/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2)); +/// let expiry = Expiry::AtDateTime(expired_at); +/// ``` +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Expiry { + /// __Browser:__ Expire on [current session end][current-session-end], as defined by the + /// browser. + /// + /// __Server:__ No expiration is set. + /// + /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie + OnSessionEnd, + + /// Expire on inactivity. + /// + /// Reading a session is not considered activity for expiration purposes. Expiration + /// is computed from the last time the session was modified. That is, when + /// the session is created ([`SessionStore::create`]), when it is saved + /// ([`SessionStore::save`]/[`SessionStore::save_or_create`]), and when its [`Id`] is cycled + /// ([`SessionStore::cycle_id`]). + /// + /// [`Id`]: crate::Id + /// [`SessionStore::create`]: crate::SessionStore::create + /// [`SessionStore::save`]: crate::SessionStore::save + /// [`SessionStore::save_or_create`]: crate::SessionStore::save_or_create + /// [`SessionStore::cycle_id`]: crate::SessionStore::cycle_id + OnInactivity(time::Duration), + + /// Expire at a specific date and time. + AtDateTime(time::OffsetDateTime), +} diff --git a/tower-sesh-core/src/id.rs b/tower-sesh-core/src/id.rs new file mode 100644 index 0000000..27c1500 --- /dev/null +++ b/tower-sesh-core/src/id.rs @@ -0,0 +1,57 @@ +//! Module for session IDs. +//! +//! [`Id`] rigourously follows the [OWASP Session Management +//! Guidelines][owasp]. +//! +//! [owasp]: +//! https://owasp.org/www-project-cheat-sheets/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-entropy +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use base64::DecodeError; +use serde::{Deserialize, Serialize}; +use std::fmt::{self, Display}; +use std::str::FromStr; + + +/// ID type for sessions. +/// +/// Wraps an array of 16 bytes. +/// +/// __Warning:__ This should be constructed [using a strong CSPRNG][csprng]. Ideally, the ID should +/// be generated by the underlying database if it provides a secure RNG source. +/// +/// If a [`SessionStore`] needs to generate IDs, it should use the [`rand`] crate, and it should +/// upstream the decision of which `Rng` provider to use through a generic parameter. +/// +/// [`SessionStore`]: crate::SessionStore +/// [csprng]: https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator +/// [`rand`]: https://crates.io/crates/rand +/// ``` +#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct Id(pub u128); + +impl Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut encoded = [0; 22]; + URL_SAFE_NO_PAD + .encode_slice(self.0.to_le_bytes(), &mut encoded) + .expect("Encoded ID must be exactly 22 bytes"); + let encoded = std::str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8"); + + f.write_str(encoded) + } +} + +impl FromStr for Id { + type Err = base64::DecodeSliceError; + + fn from_str(s: &str) -> Result { + let mut decoded = [0; 16]; + let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?; + if bytes_decoded != 16 { + let err = DecodeError::InvalidLength(bytes_decoded); + return Err(base64::DecodeSliceError::DecodeError(err)); + } + + Ok(Self(u128::from_le_bytes(decoded))) + } +} diff --git a/tower-sesh-core/src/lib.rs b/tower-sesh-core/src/lib.rs new file mode 100644 index 0000000..2481809 --- /dev/null +++ b/tower-sesh-core/src/lib.rs @@ -0,0 +1,15 @@ +//! An abstraction over session storage and retrieval through [`SessionStore`]. +//! +//! Sessions are identified by a unique [`Id`] and can have an [`Expiry`] with the [`Expires`] +//! trait. +#[doc(inline)] +pub use self::session_store::SessionStore; +pub use self::id::Id; +pub use self::expires::{Expires, Expiry}; + +/// A trait for session storage and retrieval. +pub mod session_store; +/// Session expiry configuration. +pub mod expires; +/// Session IDs. +pub mod id; diff --git a/tower-sesh-core/src/session_store.rs b/tower-sesh-core/src/session_store.rs new file mode 100644 index 0000000..755f16a --- /dev/null +++ b/tower-sesh-core/src/session_store.rs @@ -0,0 +1,265 @@ +//! A session backend for managing session state. +//! +//! This crate provides the ability to use custom backends for session +//! management by implementing the [`SessionStore`] trait. This trait defines +//! the necessary operations for creating, saving, loading, and deleting session +//! records. +//! +//! # Implementing a Custom Store +//! +//! Every method on the [`SessionStore`] trait describes precisely how it should be implemented. +//! The words _must_ and _should_ are used to indicate the level of necessity for each method. +//! Implementations _must_ adhere to the requirements of the method, while _should_ indicates a +//! recommended approach. These recommendations can be taken more lightly if the implementation is +//! for internal use only. +//! +//! TODO: List good examples of implementations. +//! +//! # CachingSessionStore +//! +//! The [`CachingSessionStore`] provides a layered caching mechanism with a +//! cache as the frontend and a store as the backend. This can improve read +//! performance by reducing the need to access the backend store for frequently +//! accessed sessions. +use std::{fmt::Debug, future::Future}; + +use either::Either::{self, Left, Right}; +use futures_util::future::try_join; +use futures_util::TryFutureExt; + +use crate::id::Id; + +/// Defines the interface for session management. +/// +/// The [`SessionStore::Error`] associated type should be used to represent hard errors that occur +/// during backend operations. For example, an implementation _must not_ return an error if a saved +/// record expired. See each method for more details. +/// __Reasoning__: The [`SessionStore`] should not be responsible for handling logic errors. +/// Methods on this trait are designed to return meaningful results for the caller to handle. The +/// `Err(...)` case is reserved for hard errors that the caller most likely cannot handle, such as +/// network errors, timeouts, invalid backend state/config, etc. These errors usually come from the +/// backend store directly, such as [`sqlx::Error`], [`redis::RedisError`], etc. +/// +/// Although recommended, it is not required for a `SessionStore` to handle session expiration. It +/// is acceptable behavior for a session to return a record that is expired. The caller should be +/// the one to decide what storage to use, and to use one that handles expiration if needed. +/// +/// For a [`SessionStore`] to be used as a middleware in a [`SessionManagerLayer`], it must also +/// implement the [`Clone`] trait. The store should also be relatively cheap to clone (a few +/// [`Arc`]s are fine). This is because the store is cloned for every request, and the user should +/// be able to clone it inside of a request handler without much overhead. +/// +/// [`sqlx::Error`]: https://docs.rs/sqlx +/// [`redis::RedisError`]: https://docs.rs/redis +// TODO: Remove all `Send` bounds once we have `return_type_notation`: +// https://github.com/rust-lang/rust/issues/109417. +pub trait SessionStore: Send + Sync { + type Error: Send; + + /// Creates a new session in the store with the provided session record. + /// + /// # Implementations + /// + /// In the successful path, Implementations _must_ return a unique ID for the provided record. + /// + /// If the a provided record is already expired, the implementation _must_ not return an error. + /// A correct implementation _should_ instead return a new ID for the record and not insert it + /// into the store, or it should let the backend store handle the expiration immediately and + /// return the new ID. + /// __Reasoning__: Creating a session that is already expired is a logical mistake, not a hard + /// error. The caller should be responsible for handling this case, when it comes time to + /// use the session. + fn create(&mut self, record: &R) -> impl Future> + Send; + + /// Saves the provided session record to the store. + /// + /// This method is intended for updating the state of an existing session. + /// + /// # Implementations + /// + /// In the successful path, implementations _must_ return `bool` indicating whether the + /// session existed and thus was updated, or if it did not exist (or was expired) and was not + /// updated. + /// __Reasoning__: The caller should be aware of whether the session was successfully updated + /// or not. If not, then this case can be handled by the caller trivially, thus it is not a + /// hard error. + /// + /// If the implementation handles expiration, id _should_ update the expiration time on the + /// session record. + fn save( + &mut self, + id: &Id, + record: &R, + ) -> impl Future> + Send; + + /// Save the provided session record to the store, and create a new one if it does not exist. + /// + /// # Implementations + /// + /// In the successful path, implementations _must_ return `Ok(())` if the record was saved or + /// created with the given ID. This method is only exposed in the API for the sake of other + /// implementations relying on generic `SessionStore` implementations (see + /// [`CachingSessionStore`]). End users using `tower-sesh` are not exposed to this method. + /// + /// If the implementation handles expiration, id _should_ update the expiration time on the + /// session record. + /// + /// # Caution + /// + /// Since the caller can potentially create a new session with a chosen ID, this method should + /// only be used by implementations when it is known that a collision will not occur. The caller + /// should not be in charge of setting the `Id`, it is rather a job for the `SessionStore` + /// through the `create` method. + /// + /// This can also accidently increase the lifetime of a session. Suppose a session is loaded + /// successfully from the store, but then expires before changes are saved. Using this method + /// will reinstate the session with the same ID, prolonging its lifetime. + fn save_or_create( + &mut self, + id: &Id, + record: &R, + ) -> impl Future> + Send; + + /// Loads an existing session record from the store using the provided ID. + /// + /// # Implementations + /// + /// If a session with the given ID exists, it is returned as `Some(record)`. If the session + /// does not exist or has been invalidated (i.e., expired), `None` is returned. + /// __Reasoning__: Loading a session that does not exist is not a hard error, and the caller + /// should be responsible for handling this case. + fn load(&mut self, id: &Id) -> impl Future, Self::Error>> + Send; + + /// Deletes a session record from the store using the provided ID. + /// + /// # Implementations + /// + /// If the session exists (and is not expired), an implmementation _must_ remove the session + /// from the store and return `Some` with the associated record. Otherwise, it must return + /// `Ok(None)`. + /// __Reasoning__: Deleting a session that does not exist is not a hard error, and the caller + /// should be responsible for handling this case. + fn delete(&mut self, id: &Id) -> impl Future> + Send; + + /// Update the ID of a session record. + /// + /// # Implementations + /// + /// This method _must_ return `Ok(None)` if the session does not exist (or is expired). + /// It _must_ return `Ok(Some(id))` with the newly assigned id if it does exist. + /// __Reasoning__: Updating the ID of a session that does not exist is not a hard error, and + /// the caller should be responsible for handling this case. + /// + /// ### Note + /// + /// The default implementation uses one `load`, one `create`, and one `delete` operation to + /// update the `Id`. it is __highly recommended__ to implmement it more efficiently whenever possible. + fn cycle_id( + &mut self, + old_id: &Id, + ) -> impl Future, Self::Error>> + Send { + async move { + let record = self.load(old_id).await?; + if let Some(record) = record { + let new_id = self.create(&record).await?; + self.delete(old_id).await?; + Ok(Some(new_id)) + } else { + Ok(None) + } + } + } +} + +/// Provides a layered caching mechanism with a cache as the frontend and a +/// store as the backend. +/// +/// By using a cache, the cost of reads can be greatly reduced as once cached, +/// reads need only interact with the frontend, forgoing the cost of retrieving +/// the session record from the backend. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct CachingSessionStore { + cache: Cache, + store: Store, +} + +impl CachingSessionStore { + /// Create a new `CachingSessionStore`. + pub fn new(cache: Cache, store: Store) -> Self { + Self { cache, store } + } +} + +impl SessionStore for CachingSessionStore +where + R: Send + Sync, + Cache: SessionStore, + Store: SessionStore, +{ + type Error = Either; + + async fn create(&mut self, record: &R) -> Result { + let id = self.store.create(record).await.map_err(Right)?; + self.cache.save_or_create(&id, record).await.map_err(Left)?; + Ok(id) + } + + async fn save(&mut self, id: &Id, record: &R) -> Result { + let store_save_fut = self.store.save(id, record).map_err(Right); + let cache_save_fut = self.cache.save(id, record).map_err(Left); + + let (exists_cache, exists_store) = try_join(cache_save_fut, store_save_fut).await?; + + if !exists_store && exists_cache { + self.cache.delete(id).await.map_err(Left)?; + } + + Ok(exists_store) + } + + async fn save_or_create(&mut self, id: &Id, record: &R) -> Result<(), Self::Error> { + let store_save_fut = self.store.save_or_create(id, record).map_err(Right); + let cache_save_fut = self.cache.save_or_create(id, record).map_err(Left); + + try_join(cache_save_fut, store_save_fut).await?; + + Ok(()) + } + + async fn load(&mut self, id: &Id) -> Result, Self::Error> { + match self.cache.load(id).await { + Ok(Some(session_record)) => Ok(Some(session_record)), + Ok(None) => { + let session_record = self.store.load(id).await.map_err(Right)?; + + if let Some(ref session_record) = session_record { + self.cache + .save(id, session_record) + .await + .map_err(Either::Left)?; + } + + Ok(session_record) + } + Err(err) => Err(Left(err)), + } + } + + async fn delete(&mut self, id: &Id) -> Result { + let store_delete_fut = self.store.delete(id).map_err(Right); + let cache_delete_fut = self.cache.delete(id).map_err(Left); + + let (_, in_store) = try_join(cache_delete_fut, store_delete_fut).await?; + + Ok(in_store) + } + + async fn cycle_id(&mut self, old_id: &Id) -> Result, Self::Error> { + let delete_cache = self.cache.delete(old_id).map_err(Left); + let new_id = self.store.cycle_id(old_id).map_err(Right); + + try_join(delete_cache, new_id) + .await + .map(|(_, new_id)| new_id) + } +} diff --git a/tower-sessions-core/Cargo.toml b/tower-sessions-core/Cargo.toml deleted file mode 100644 index 3516123..0000000 --- a/tower-sessions-core/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -[package] -name = "tower-sessions-core" -description = "Core types and traits for tower-sessions." -documentation.workspace = true -version.workspace = true -license.workspace = true -edition.workspace = true -authors.workspace = true -repository.workspace = true - -[features] -default = [] -axum-core = ["dep:axum-core"] -deletion-task = ["tokio/time"] - -[dependencies] -async-trait = { workspace = true } -axum-core = { version = "0.4", optional = true } -base64 = "0.22.0" -futures = { version = "0.3.28", default-features = false, features = [ - "async-await", -] } -http = "1.0" -parking_lot = { version = "0.12.1", features = ["serde"] } -rand = "0.8.5" -serde = { version = "1.0.189", features = ["derive", "rc"] } -serde_json = "1.0.107" -thiserror = { workspace = true } -time = { version = "0.3.29", features = ["serde"] } -tokio = { workspace = true } -tracing = { version = "0.1.40", features = ["log"] } - -[dev-dependencies] -tower-sessions = { workspace = true, features = ["memory-store"] } -tokio-test = "0.4.3" -tokio = { workspace = true, features = ["rt", "macros"] } -mockall = "0.13.0" diff --git a/tower-sessions-core/src/extract.rs b/tower-sessions-core/src/extract.rs deleted file mode 100644 index ee9d1aa..0000000 --- a/tower-sessions-core/src/extract.rs +++ /dev/null @@ -1,20 +0,0 @@ -use async_trait::async_trait; -use axum_core::extract::FromRequestParts; -use http::{request::Parts, StatusCode}; - -use crate::session::Session; - -#[async_trait] -impl FromRequestParts for Session -where - S: Sync + Send, -{ - type Rejection = (http::StatusCode, &'static str); - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - parts.extensions.get::().cloned().ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - "Can't extract session. Is `SessionManagerLayer` enabled?", - )) - } -} diff --git a/tower-sessions-core/src/lib.rs b/tower-sessions-core/src/lib.rs deleted file mode 100644 index e775c0a..0000000 --- a/tower-sessions-core/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -#[doc(inline)] -pub use self::{ - session::{Expiry, Session}, - session_store::{CachingSessionStore, ExpiredDeletion, SessionStore}, -}; - -#[cfg(feature = "axum-core")] -#[cfg_attr(docsrs, doc(cfg(feature = "axum-core")))] -pub mod extract; -pub mod session; -pub mod session_store; diff --git a/tower-sessions-core/src/session.rs b/tower-sessions-core/src/session.rs deleted file mode 100644 index c4ed43a..0000000 --- a/tower-sessions-core/src/session.rs +++ /dev/null @@ -1,1052 +0,0 @@ -//! A session which allows HTTP applications to associate data with visitors. -use std::{ - collections::HashMap, - fmt::{self, Display}, - hash::Hash, - result, - str::{self, FromStr}, - sync::{ - atomic::{self, AtomicBool}, - Arc, - }, -}; - -use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Value; -use time::{Duration, OffsetDateTime}; -use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard}; - -use crate::{session_store, SessionStore}; - -const DEFAULT_DURATION: Duration = Duration::weeks(2); - -type Result = result::Result; - -type Data = HashMap; - -/// Session errors. -#[derive(thiserror::Error, Debug)] -pub enum Error { - /// Maps `serde_json` errors. - #[error(transparent)] - SerdeJson(#[from] serde_json::Error), - - /// Maps `session_store::Error` errors. - #[error(transparent)] - Store(#[from] session_store::Error), -} - -#[derive(Debug)] -struct Inner { - // This will be `None` when: - // - // 1. We have not been provided a session cookie or have failed to parse it, - // 2. The store has not found the session. - // - // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use - session_id: parking_lot::Mutex>, - - // A lazy representation of the session's value, hydrated on a just-in-time basis. A - // `None` value indicates we have not tried to access it yet. After access, it will always - // contain `Some(Record)`. - record: Mutex>, - - // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use - expiry: parking_lot::Mutex>, - - is_modified: AtomicBool, -} - -/// A session which allows HTTP applications to associate key-value pairs with -/// visitors. -#[derive(Debug, Clone)] -pub struct Session { - store: Arc, - inner: Arc, -} - -impl Session { - /// Creates a new session with the session ID, store, and expiry. - /// - /// This method is lazy and does not invoke the overhead of talking to the - /// backing store. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// Session::new(None, store, None); - /// ``` - pub fn new( - session_id: Option, - store: Arc, - expiry: Option, - ) -> Self { - let inner = Inner { - session_id: parking_lot::Mutex::new(session_id), - record: Mutex::new(None), // `None` indicates we have not loaded from store. - expiry: parking_lot::Mutex::new(expiry), - is_modified: AtomicBool::new(false), - }; - - Self { - store, - inner: Arc::new(inner), - } - } - - fn create_record(&self) -> Record { - Record::new(self.expiry_date()) - } - - #[tracing::instrument(skip(self), err)] - async fn get_record(&self) -> Result> { - let mut record_guard = self.inner.record.lock().await; - - // Lazily load the record since `None` here indicates we have no yet loaded it. - if record_guard.is_none() { - tracing::trace!("record not loaded from store; loading"); - - let session_id = *self.inner.session_id.lock(); - *record_guard = Some(if let Some(session_id) = session_id { - match self.store.load(&session_id).await? { - Some(loaded_record) => { - tracing::trace!("record found in store"); - loaded_record - } - - None => { - // A well-behaved user agent should not send session cookies after - // expiration. Even so it's possible for an expired session to be removed - // from the store after a request was initiated. However, such a race should - // be relatively uncommon and as such entering this branch could indicate - // malicious behavior. - tracing::warn!("possibly suspicious activity: record not found in store"); - *self.inner.session_id.lock() = None; - self.create_record() - } - } - } else { - tracing::trace!("session id not found"); - self.create_record() - }) - } - - Ok(MutexGuard::map(record_guard, |opt| { - opt.as_mut() - .expect("Record should always be `Option::Some` at this point") - })) - } - - /// Inserts a `impl Serialize` value into the session. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// - /// let value = session.get::("foo").await.unwrap(); - /// assert_eq!(value, Some(42)); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - This method can fail when [`serde_json::to_value`] fails. - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> { - self.insert_value(key, serde_json::to_value(&value)?) - .await?; - Ok(()) - } - - /// Inserts a `serde_json::Value` into the session. - /// - /// If the key was not present in the underlying map, `None` is returned and - /// `modified` is set to `true`. - /// - /// If the underlying map did have the key and its value is the same as the - /// provided value, `None` is returned and `modified` is not set. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// let value = session - /// .insert_value("foo", serde_json::json!(42)) - /// .await - /// .unwrap(); - /// assert!(value.is_none()); - /// - /// let value = session - /// .insert_value("foo", serde_json::json!(42)) - /// .await - /// .unwrap(); - /// assert!(value.is_none()); - /// - /// let value = session - /// .insert_value("foo", serde_json::json!("bar")) - /// .await - /// .unwrap(); - /// assert_eq!(value, Some(serde_json::json!(42))); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn insert_value(&self, key: &str, value: Value) -> Result> { - let mut record_guard = self.get_record().await?; - Ok(if record_guard.data.get(key) != Some(&value) { - self.inner - .is_modified - .store(true, atomic::Ordering::Release); - record_guard.data.insert(key.to_string(), value) - } else { - None - }) - } - - /// Gets a value from the store. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// - /// let value = session.get::("foo").await.unwrap(); - /// assert_eq!(value, Some(42)); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - This method can fail when [`serde_json::from_value`] fails. - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn get(&self, key: &str) -> Result> { - Ok(self - .get_value(key) - .await? - .map(serde_json::from_value) - .transpose()?) - } - - /// Gets a `serde_json::Value` from the store. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// - /// let value = session.get_value("foo").await.unwrap().unwrap(); - /// assert_eq!(value, serde_json::json!(42)); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn get_value(&self, key: &str) -> Result> { - let record_guard = self.get_record().await?; - Ok(record_guard.data.get(key).cloned()) - } - - /// Removes a value from the store, retuning the value of the key if it was - /// present in the underlying map. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// - /// let value: Option = session.remove("foo").await.unwrap(); - /// assert_eq!(value, Some(42)); - /// - /// let value: Option = session.get("foo").await.unwrap(); - /// assert!(value.is_none()); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - This method can fail when [`serde_json::from_value`] fails. - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn remove(&self, key: &str) -> Result> { - Ok(self - .remove_value(key) - .await? - .map(serde_json::from_value) - .transpose()?) - } - - /// Removes a `serde_json::Value` from the session. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// let value = session.remove_value("foo").await.unwrap().unwrap(); - /// assert_eq!(value, serde_json::json!(42)); - /// - /// let value: Option = session.get("foo").await.unwrap(); - /// assert!(value.is_none()); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If the session has not been hydrated and loading from the store fails, - /// we fail with [`Error::Store`]. - pub async fn remove_value(&self, key: &str) -> Result> { - let mut record_guard = self.get_record().await?; - self.inner - .is_modified - .store(true, atomic::Ordering::Release); - Ok(record_guard.data.remove(key)) - } - - /// Clears the session of all data but does not delete it from the store. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// - /// let session = Session::new(None, store.clone(), None); - /// session.insert("foo", 42).await.unwrap(); - /// assert!(!session.is_empty().await); - /// - /// session.save().await.unwrap(); - /// - /// session.clear().await; - /// - /// // Not empty! (We have an ID still.) - /// assert!(!session.is_empty().await); - /// // Data is cleared... - /// assert!(session.get::("foo").await.unwrap().is_none()); - /// - /// // ...data is cleared before loading from the backend... - /// let session = Session::new(session.id(), store.clone(), None); - /// session.clear().await; - /// assert!(session.get::("foo").await.unwrap().is_none()); - /// - /// let session = Session::new(session.id(), store, None); - /// // ...but data is not deleted from the store. - /// assert_eq!(session.get::("foo").await.unwrap(), Some(42)); - /// # }); - /// ``` - pub async fn clear(&self) { - let mut record_guard = self.inner.record.lock().await; - if let Some(record) = record_guard.as_mut() { - record.data.clear(); - } else if let Some(session_id) = *self.inner.session_id.lock() { - let mut new_record = self.create_record(); - new_record.id = session_id; - *record_guard = Some(new_record); - } - - self.inner - .is_modified - .store(true, atomic::Ordering::Release); - } - - /// Returns `true` if there is no session ID and the session is empty. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Id, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// - /// let session = Session::new(None, store.clone(), None); - /// // Empty if we have no ID and record is not loaded. - /// assert!(session.is_empty().await); - /// - /// let session = Session::new(Some(Id::default()), store.clone(), None); - /// // Not empty if we have an ID but no record. (Record is not loaded here.) - /// assert!(!session.is_empty().await); - /// - /// let session = Session::new(Some(Id::default()), store.clone(), None); - /// session.insert("foo", 42).await.unwrap(); - /// // Not empty after inserting. - /// assert!(!session.is_empty().await); - /// session.save().await.unwrap(); - /// // Not empty after saving. - /// assert!(!session.is_empty().await); - /// - /// let session = Session::new(session.id(), store.clone(), None); - /// session.load().await.unwrap(); - /// // Not empty after loading from store... - /// assert!(!session.is_empty().await); - /// // ...and not empty after accessing the session. - /// session.get::("foo").await.unwrap(); - /// assert!(!session.is_empty().await); - /// - /// let session = Session::new(session.id(), store.clone(), None); - /// session.delete().await.unwrap(); - /// // Not empty after deleting from store... - /// assert!(!session.is_empty().await); - /// session.get::("foo").await.unwrap(); - /// // ...but empty after trying to access the deleted session. - /// assert!(session.is_empty().await); - /// - /// let session = Session::new(None, store, None); - /// session.insert("foo", 42).await.unwrap(); - /// session.flush().await.unwrap(); - /// // Empty after flushing. - /// assert!(session.is_empty().await); - /// # }); - /// ``` - pub async fn is_empty(&self) -> bool { - let record_guard = self.inner.record.lock().await; - - // N.B.: Session IDs are `None` if: - // - // 1. The cookie was not provided or otherwise could not be parsed, - // 2. Or the session could not be loaded from the store. - let session_id = self.inner.session_id.lock(); - - let Some(record) = record_guard.as_ref() else { - return session_id.is_none(); - }; - - session_id.is_none() && record.data.is_empty() - } - - /// Get the session ID. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Id, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// - /// let session = Session::new(None, store.clone(), None); - /// assert!(session.id().is_none()); - /// - /// let id = Some(Id::default()); - /// let session = Session::new(id, store, None); - /// assert_eq!(id, session.id()); - /// ``` - pub fn id(&self) -> Option { - *self.inner.session_id.lock() - } - - /// Get the session expiry. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Expiry, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// assert_eq!(session.expiry(), None); - /// ``` - pub fn expiry(&self) -> Option { - *self.inner.expiry.lock() - } - - /// Set `expiry` to the given value. - /// - /// This may be used within applications directly to alter the session's - /// time to live. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use time::OffsetDateTime; - /// use tower_sessions::{session::Expiry, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc()); - /// session.set_expiry(Some(expiry)); - /// - /// assert_eq!(session.expiry(), Some(expiry)); - /// ``` - pub fn set_expiry(&self, expiry: Option) { - *self.inner.expiry.lock() = expiry; - self.inner - .is_modified - .store(true, atomic::Ordering::Release); - } - - /// Get session expiry as `OffsetDateTime`. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use time::{Duration, OffsetDateTime}; - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// // Our default duration is two weeks. - /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2)); - /// - /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1))); - /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1))); - /// ``` - pub fn expiry_date(&self) -> OffsetDateTime { - let expiry = self.inner.expiry.lock(); - match *expiry { - Some(Expiry::OnInactivity(duration)) => { - OffsetDateTime::now_utc().saturating_add(duration) - } - Some(Expiry::AtDateTime(datetime)) => datetime, - Some(Expiry::OnSessionEnd) | None => { - OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION) // TODO: The default should probably be configurable. - } - } - } - - /// Get session expiry as `Duration`. - /// - /// # Examples - /// - /// ```rust - /// use std::sync::Arc; - /// - /// use time::Duration; - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// let expected_duration = Duration::weeks(2); - /// - /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1))); - /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1))); - /// ``` - pub fn expiry_age(&self) -> Duration { - std::cmp::max( - self.expiry_date() - OffsetDateTime::now_utc(), - Duration::ZERO, - ) - } - - /// Returns `true` if the session has been modified during the request. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store, None); - /// - /// // Not modified initially. - /// assert!(!session.is_modified()); - /// - /// // Getting doesn't count as a modification. - /// session.get::("foo").await.unwrap(); - /// assert!(!session.is_modified()); - /// - /// // Insertions and removals do though. - /// session.insert("foo", 42).await.unwrap(); - /// assert!(session.is_modified()); - /// # }); - /// ``` - pub fn is_modified(&self) -> bool { - self.inner.is_modified.load(atomic::Ordering::Acquire) - } - - /// Saves the session record to the store. - /// - /// Note that this method is generally not needed and is reserved for - /// situations where the session store must be updated during the - /// request. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store.clone(), None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// session.save().await.unwrap(); - /// - /// let session = Session::new(session.id(), store, None); - /// assert_eq!(session.get::("foo").await.unwrap().unwrap(), 42); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If saving to the store fails, we fail with [`Error::Store`]. - #[tracing::instrument(skip(self), err)] - pub async fn save(&self) -> Result<()> { - let mut record_guard = self.get_record().await?; - record_guard.expiry_date = self.expiry_date(); - - // Session ID is `None` if: - // - // 1. No valid cookie was found on the request or, - // 2. No valid session was found in the store. - // - // In either case, we must create a new session via the store interface. - // - // Potential ID collisions must be handled by session store implementers. - if self.inner.session_id.lock().is_none() { - self.store.create(&mut record_guard).await?; - *self.inner.session_id.lock() = Some(record_guard.id); - } else { - self.store.save(&record_guard).await?; - } - Ok(()) - } - - /// Loads the session record from the store. - /// - /// Note that this method is generally not needed and is reserved for - /// situations where the session must be updated during the request. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Id, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let id = Some(Id::default()); - /// let session = Session::new(id, store.clone(), None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// session.save().await.unwrap(); - /// - /// let session = Session::new(session.id(), store, None); - /// session.load().await.unwrap(); - /// - /// assert_eq!(session.get::("foo").await.unwrap().unwrap(), 42); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If loading from the store fails, we fail with [`Error::Store`]. - #[tracing::instrument(skip(self), err)] - pub async fn load(&self) -> Result<()> { - let session_id = *self.inner.session_id.lock(); - let Some(ref id) = session_id else { - tracing::warn!("called load with no session id"); - return Ok(()); - }; - let loaded_record = self.store.load(id).await.map_err(Error::Store)?; - let mut record_guard = self.inner.record.lock().await; - *record_guard = loaded_record; - Ok(()) - } - - /// Deletes the session from the store. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Id, MemoryStore, Session, SessionStore}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(Some(Id::default()), store.clone(), None); - /// - /// // Save before deleting. - /// session.save().await.unwrap(); - /// - /// // Delete from the store. - /// session.delete().await.unwrap(); - /// - /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none()); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If deleting from the store fails, we fail with [`Error::Store`]. - #[tracing::instrument(skip(self), err)] - pub async fn delete(&self) -> Result<()> { - let session_id = *self.inner.session_id.lock(); - let Some(ref session_id) = session_id else { - tracing::warn!("called delete with no session id"); - return Ok(()); - }; - self.store.delete(session_id).await.map_err(Error::Store)?; - Ok(()) - } - - /// Flushes the session by removing all data contained in the session and - /// then deleting it from the store. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{MemoryStore, Session, SessionStore}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store.clone(), None); - /// - /// session.insert("foo", "bar").await.unwrap(); - /// session.save().await.unwrap(); - /// - /// let id = session.id().unwrap(); - /// - /// session.flush().await.unwrap(); - /// - /// assert!(session.id().is_none()); - /// assert!(session.is_empty().await); - /// assert!(store.load(&id).await.unwrap().is_none()); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If deleting from the store fails, we fail with [`Error::Store`]. - pub async fn flush(&self) -> Result<()> { - self.clear().await; - self.delete().await?; - *self.inner.session_id.lock() = None; - Ok(()) - } - - /// Cycles the session ID while retaining any data that was associated with - /// it. - /// - /// Using this method helps prevent session fixation attacks by ensuring a - /// new ID is assigned to the session. - /// - /// # Examples - /// - /// ```rust - /// # tokio_test::block_on(async { - /// use std::sync::Arc; - /// - /// use tower_sessions::{session::Id, MemoryStore, Session}; - /// - /// let store = Arc::new(MemoryStore::default()); - /// let session = Session::new(None, store.clone(), None); - /// - /// session.insert("foo", 42).await.unwrap(); - /// session.save().await.unwrap(); - /// let id = session.id(); - /// - /// let session = Session::new(session.id(), store.clone(), None); - /// session.cycle_id().await.unwrap(); - /// - /// assert!(!session.is_empty().await); - /// assert!(session.is_modified()); - /// - /// session.save().await.unwrap(); - /// - /// let session = Session::new(session.id(), store, None); - /// - /// assert_ne!(id, session.id()); - /// assert_eq!(session.get::("foo").await.unwrap().unwrap(), 42); - /// # }); - /// ``` - /// - /// # Errors - /// - /// - If deleting from the store fails or saving to the store fails, we fail - /// with [`Error::Store`]. - pub async fn cycle_id(&self) -> Result<()> { - let mut record_guard = self.get_record().await?; - - let old_session_id = record_guard.id; - record_guard.id = Id::default(); - *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's - // `create` method. - - self.store - .delete(&old_session_id) - .await - .map_err(Error::Store)?; - - self.inner - .is_modified - .store(true, atomic::Ordering::Release); - - Ok(()) - } -} - -/// ID type for sessions. -/// -/// Wraps an array of 16 bytes. -/// -/// # Examples -/// -/// ```rust -/// use tower_sessions::session::Id; -/// -/// Id::default(); -/// ``` -#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)] -pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the - // session ID, which is undesirable. - -impl Default for Id { - fn default() -> Self { - use rand::prelude::*; - - Self(rand::thread_rng().gen()) - } -} - -impl Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut encoded = [0; 22]; - URL_SAFE_NO_PAD - .encode_slice(self.0.to_le_bytes(), &mut encoded) - .expect("Encoded ID must be exactly 22 bytes"); - let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8"); - - f.write_str(encoded) - } -} - -impl FromStr for Id { - type Err = base64::DecodeSliceError; - - fn from_str(s: &str) -> result::Result { - let mut decoded = [0; 16]; - let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?; - if bytes_decoded != 16 { - let err = DecodeError::InvalidLength(bytes_decoded); - return Err(base64::DecodeSliceError::DecodeError(err)); - } - - Ok(Self(i128::from_le_bytes(decoded))) - } -} - -/// Record type that's appropriate for encoding and decoding sessions to and -/// from session stores. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Record { - pub id: Id, - pub data: Data, - pub expiry_date: OffsetDateTime, -} - -impl Record { - fn new(expiry_date: OffsetDateTime) -> Self { - Self { - id: Id::default(), - data: Data::default(), - expiry_date, - } - } -} - -/// Session expiry configuration. -/// -/// # Examples -/// -/// ```rust -/// use time::{Duration, OffsetDateTime}; -/// use tower_sessions::Expiry; -/// -/// // Will be expired on "session end". -/// let expiry = Expiry::OnSessionEnd; -/// -/// // Will be expired in five minutes from last acitve. -/// let expiry = Expiry::OnInactivity(Duration::minutes(5)); -/// -/// // Will be expired at the given timestamp. -/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2)); -/// let expiry = Expiry::AtDateTime(expired_at); -/// ``` -#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum Expiry { - /// Expire on [current session end][current-session-end], as defined by the - /// browser. - /// - /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie - OnSessionEnd, - - /// Expire on inactivity. - /// - /// Reading a session is not considered activity for expiration purposes. - /// [`Session`] expiration is computed from the last time the session was - /// _modified_. - OnInactivity(Duration), - - /// Expire at a specific date and time. - /// - /// This value may be extended manually with - /// [`set_expiry`](Session::set_expiry). - AtDateTime(OffsetDateTime), -} - -#[cfg(test)] -mod tests { - use async_trait::async_trait; - use mockall::{ - mock, - predicate::{self, always}, - }; - - use super::*; - - mock! { - #[derive(Debug)] - pub Store {} - - #[async_trait] - impl SessionStore for Store { - async fn create(&self, record: &mut Record) -> session_store::Result<()>; - async fn save(&self, record: &Record) -> session_store::Result<()>; - async fn load(&self, session_id: &Id) -> session_store::Result>; - async fn delete(&self, session_id: &Id) -> session_store::Result<()>; - } - } - - #[tokio::test] - async fn test_cycle_id() { - let mut mock_store = MockStore::new(); - - let initial_id = Id::default(); - let new_id = Id::default(); - - // Set up expectations for the mock store - mock_store - .expect_save() - .with(always()) - .times(1) - .returning(|_| Ok(())); - mock_store - .expect_load() - .with(predicate::eq(initial_id)) - .times(1) - .returning(move |_| { - Ok(Some(Record { - id: initial_id, - data: Data::default(), - expiry_date: OffsetDateTime::now_utc(), - })) - }); - mock_store - .expect_delete() - .with(predicate::eq(initial_id)) - .times(1) - .returning(|_| Ok(())); - mock_store - .expect_create() - .times(1) - .returning(move |record| { - record.id = new_id; - Ok(()) - }); - - let store = Arc::new(mock_store); - let session = Session::new(Some(initial_id), store.clone(), None); - - // Insert some data and save the session - session.insert("foo", 42).await.unwrap(); - session.save().await.unwrap(); - - // Cycle the session ID - session.cycle_id().await.unwrap(); - - // Verify that the session ID has changed and the data is still present - assert_ne!(session.id(), Some(initial_id)); - assert!(session.id().is_none()); // The session ID should be None - assert_eq!(session.get::("foo").await.unwrap(), Some(42)); - - // Save the session to update the ID in the session object - session.save().await.unwrap(); - assert_eq!(session.id(), Some(new_id)); - } -} diff --git a/tower-sessions-core/src/session_store.rs b/tower-sessions-core/src/session_store.rs deleted file mode 100644 index 0a605eb..0000000 --- a/tower-sessions-core/src/session_store.rs +++ /dev/null @@ -1,504 +0,0 @@ -//! A session backend for managing session state. -//! -//! This crate provides the ability to use custom backends for session -//! management by implementing the [`SessionStore`] trait. This trait defines -//! the necessary operations for creating, saving, loading, and deleting session -//! records. -//! -//! # Implementing a Custom Store -//! -//! Below is an example of implementing a custom session store using an -//! in-memory [`HashMap`]. This example is for illustration purposes only; you -//! can use the provided [`MemoryStore`] directly without implementing it -//! yourself. -//! -//! ```rust -//! use std::{collections::HashMap, sync::Arc}; -//! -//! use async_trait::async_trait; -//! use time::OffsetDateTime; -//! use tokio::sync::Mutex; -//! use tower_sessions_core::{ -//! session::{Id, Record}, -//! session_store, SessionStore, -//! }; -//! -//! #[derive(Clone, Debug, Default)] -//! pub struct MemoryStore(Arc>>); -//! -//! #[async_trait] -//! impl SessionStore for MemoryStore { -//! async fn create(&self, record: &mut Record) -> session_store::Result<()> { -//! let mut store_guard = self.0.lock().await; -//! while store_guard.contains_key(&record.id) { -//! // Session ID collision mitigation. -//! record.id = Id::default(); -//! } -//! store_guard.insert(record.id, record.clone()); -//! Ok(()) -//! } -//! -//! async fn save(&self, record: &Record) -> session_store::Result<()> { -//! self.0.lock().await.insert(record.id, record.clone()); -//! Ok(()) -//! } -//! -//! async fn load(&self, session_id: &Id) -> session_store::Result> { -//! Ok(self -//! .0 -//! .lock() -//! .await -//! .get(session_id) -//! .filter(|Record { expiry_date, .. }| is_active(*expiry_date)) -//! .cloned()) -//! } -//! -//! async fn delete(&self, session_id: &Id) -> session_store::Result<()> { -//! self.0.lock().await.remove(session_id); -//! Ok(()) -//! } -//! } -//! -//! fn is_active(expiry_date: OffsetDateTime) -> bool { -//! expiry_date > OffsetDateTime::now_utc() -//! } -//! ``` -//! -//! # Session Store Trait -//! -//! The [`SessionStore`] trait defines the interface for session management. -//! Implementations must handle session creation, saving, loading, and deletion. -//! -//! # CachingSessionStore -//! -//! The [`CachingSessionStore`] provides a layered caching mechanism with a -//! cache as the frontend and a store as the backend. This can improve read -//! performance by reducing the need to access the backend store for frequently -//! accessed sessions. -//! -//! # ExpiredDeletion -//! -//! The [`ExpiredDeletion`] trait provides a method for deleting expired -//! sessions. Implementations can optionally provide a method for continuously -//! deleting expired sessions at a specified interval. -use std::fmt::Debug; - -use async_trait::async_trait; - -use crate::session::{Id, Record}; - -/// Stores must map any errors that might occur during their use to this type. -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Encoding failed with: {0}")] - Encode(String), - - #[error("Decoding failed with: {0}")] - Decode(String), - - #[error("{0}")] - Backend(String), -} - -pub type Result = std::result::Result; - -/// Defines the interface for session management. -/// -/// See [`session_store`](crate::session_store) for more details. -#[async_trait] -pub trait SessionStore: Debug + Send + Sync + 'static { - /// Creates a new session in the store with the provided session record. - /// - /// Implementers must decide how to handle potential ID collisions. For - /// example, they might generate a new unique ID or return `Error::Backend`. - /// - /// The record is given as an exclusive reference to allow modifications, - /// such as assigning a new ID, during the creation process. - async fn create(&self, session_record: &mut Record) -> Result<()> { - default_create(self, session_record).await - } - - /// Saves the provided session record to the store. - /// - /// This method is intended for updating the state of an existing session. - async fn save(&self, session_record: &Record) -> Result<()>; - - /// Loads an existing session record from the store using the provided ID. - /// - /// If a session with the given ID exists, it is returned. If the session - /// does not exist or has been invalidated (e.g., expired), `None` is - /// returned. - async fn load(&self, session_id: &Id) -> Result>; - - /// Deletes a session record from the store using the provided ID. - /// - /// If the session exists, it is removed from the store. - async fn delete(&self, session_id: &Id) -> Result<()>; -} - -async fn default_create( - store: &S, - session_record: &mut Record, -) -> Result<()> { - tracing::warn!( - "The default implementation of `SessionStore::create` is being used, which relies on \ - `SessionStore::save`. To properly handle potential ID collisions, it is recommended that \ - stores implement their own version of `SessionStore::create`." - ); - store.save(session_record).await?; - Ok(()) -} - -/// Provides a layered caching mechanism with a cache as the frontend and a -/// store as the backend.. -/// -/// Contains both a cache, which acts as a frontend, and a store which acts as a -/// backend. Both cache and store implement `SessionStore`. -/// -/// By using a cache, the cost of reads can be greatly reduced as once cached, -/// reads need only interact with the frontend, forgoing the cost of retrieving -/// the session record from the backend. -/// -/// # Examples -/// -/// ```rust,ignore -/// # tokio_test::block_on(async { -/// use tower_sessions::CachingSessionStore; -/// use tower_sessions_moka_store::MokaStore; -/// use tower_sessions_sqlx_store::{SqlitePool, SqliteStore}; -/// let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); -/// let sqlite_store = SqliteStore::new(pool); -/// let moka_store = MokaStore::new(Some(2_000)); -/// let caching_store = CachingSessionStore::new(moka_store, sqlite_store); -/// # }) -/// ``` -#[derive(Debug, Clone)] -pub struct CachingSessionStore { - cache: Cache, - store: Store, -} - -impl CachingSessionStore { - /// Create a new `CachingSessionStore`. - pub fn new(cache: Cache, store: Store) -> Self { - Self { cache, store } - } -} - -#[async_trait] -impl SessionStore for CachingSessionStore -where - Cache: SessionStore, - Store: SessionStore, -{ - async fn create(&self, record: &mut Record) -> Result<()> { - self.store.create(record).await?; - self.cache.create(record).await?; - Ok(()) - } - - async fn save(&self, record: &Record) -> Result<()> { - let store_save_fut = self.store.save(record); - let cache_save_fut = self.cache.save(record); - - futures::try_join!(store_save_fut, cache_save_fut)?; - - Ok(()) - } - - async fn load(&self, session_id: &Id) -> Result> { - match self.cache.load(session_id).await { - // We found a session in the cache, so let's use it. - Ok(Some(session_record)) => Ok(Some(session_record)), - - // We didn't find a session in the cache, so we'll try loading from the backend. - // - // When we find a session in the backend, we'll hydrate our cache with it. - Ok(None) => { - let session_record = self.store.load(session_id).await?; - - if let Some(ref session_record) = session_record { - self.cache.save(session_record).await?; - } - - Ok(session_record) - } - - // Some error occurred with our cache so we'll bubble this up. - Err(err) => Err(err), - } - } - - async fn delete(&self, session_id: &Id) -> Result<()> { - let store_delete_fut = self.store.delete(session_id); - let cache_delete_fut = self.cache.delete(session_id); - - futures::try_join!(store_delete_fut, cache_delete_fut)?; - - Ok(()) - } -} - -/// Provides a method for deleting expired sessions. -#[async_trait] -pub trait ExpiredDeletion: SessionStore -where - Self: Sized, -{ - /// A method for deleting expired sessions from the store. - async fn delete_expired(&self) -> Result<()>; - - /// This function will keep running indefinitely, deleting expired rows and - /// then waiting for the specified period before deleting again. - /// - /// Generally this will be used as a task, for example via - /// `tokio::task::spawn`. - /// - /// # Errors - /// - /// This function returns a `Result` that contains an error of type - /// `sqlx::Error` if the deletion operation fails. - /// - /// # Examples - /// - /// ```rust,no_run,ignore - /// use tower_sessions::session_store::ExpiredDeletion; - /// use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore}; - /// - /// # { - /// # tokio_test::block_on(async { - /// let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); - /// let session_store = SqliteStore::new(pool); - /// - /// tokio::task::spawn( - /// session_store - /// .clone() - /// .continuously_delete_expired(tokio::time::Duration::from_secs(60)), - /// ); - /// # }) - /// ``` - #[cfg(feature = "deletion-task")] - #[cfg_attr(docsrs, doc(cfg(feature = "deletion-task")))] - async fn continuously_delete_expired(self, period: tokio::time::Duration) -> Result<()> { - let mut interval = tokio::time::interval(period); - interval.tick().await; // The first tick completes immediately; skip. - loop { - interval.tick().await; - self.delete_expired().await?; - } - } -} - -#[cfg(test)] -mod tests { - use mockall::{ - mock, - predicate::{self, *}, - }; - use time::{Duration, OffsetDateTime}; - - use super::*; - - mock! { - #[derive(Debug)] - pub Cache {} - - #[async_trait] - impl SessionStore for Cache { - async fn create(&self, record: &mut Record) -> Result<()>; - async fn save(&self, record: &Record) -> Result<()>; - async fn load(&self, session_id: &Id) -> Result>; - async fn delete(&self, session_id: &Id) -> Result<()>; - } - } - - mock! { - #[derive(Debug)] - pub Store {} - - #[async_trait] - impl SessionStore for Store { - async fn create(&self, record: &mut Record) -> Result<()>; - async fn save(&self, record: &Record) -> Result<()>; - async fn load(&self, session_id: &Id) -> Result>; - async fn delete(&self, session_id: &Id) -> Result<()>; - } - } - - mock! { - #[derive(Debug)] - pub CollidingStore {} - - #[async_trait] - impl SessionStore for CollidingStore { - async fn save(&self, record: &Record) -> Result<()>; - async fn load(&self, session_id: &Id) -> Result>; - async fn delete(&self, session_id: &Id) -> Result<()>; - } - } - - #[tokio::test] - async fn test_create() { - let mut store = MockCollidingStore::new(); - let mut record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - - store - .expect_save() - .with(predicate::eq(record.clone())) - .times(1) - .returning(|_| Ok(())); - let result = store.create(&mut record).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_save() { - let mut store = MockStore::new(); - let record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - store - .expect_save() - .with(predicate::eq(record.clone())) - .times(1) - .returning(|_| Ok(())); - - let result = store.save(&record).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_load() { - let mut store = MockStore::new(); - let session_id = Id::default(); - let record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - let expected_record = record.clone(); - - store - .expect_load() - .with(predicate::eq(session_id)) - .times(1) - .returning(move |_| Ok(Some(record.clone()))); - - let result = store.load(&session_id).await; - assert!(result.is_ok()); - assert_eq!(result.unwrap(), Some(expected_record)); - } - - #[tokio::test] - async fn test_delete() { - let mut store = MockStore::new(); - let session_id = Id::default(); - - store - .expect_delete() - .with(predicate::eq(session_id)) - .times(1) - .returning(|_| Ok(())); - - let result = store.delete(&session_id).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_caching_store_create() { - let mut cache = MockCache::new(); - let mut store = MockStore::new(); - let mut record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - - cache.expect_create().times(1).returning(|_| Ok(())); - store.expect_create().times(1).returning(|_| Ok(())); - - let caching_store = CachingSessionStore::new(cache, store); - let result = caching_store.create(&mut record).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_caching_store_save() { - let mut cache = MockCache::new(); - let mut store = MockStore::new(); - let record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - - cache - .expect_save() - .with(predicate::eq(record.clone())) - .times(1) - .returning(|_| Ok(())); - store - .expect_save() - .with(predicate::eq(record.clone())) - .times(1) - .returning(|_| Ok(())); - - let caching_store = CachingSessionStore::new(cache, store); - let result = caching_store.save(&record).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_caching_store_load() { - let mut cache = MockCache::new(); - let mut store = MockStore::new(); - let session_id = Id::default(); - let record = Record { - id: Default::default(), - data: Default::default(), - expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30), - }; - let expected_record = record.clone(); - - cache - .expect_load() - .with(predicate::eq(session_id)) - .times(1) - .returning(move |_| Ok(Some(record.clone()))); - // Store load should not be called since cache returns a record - store.expect_load().times(0); - - let caching_store = CachingSessionStore::new(cache, store); - let result = caching_store.load(&session_id).await; - assert!(result.is_ok()); - assert_eq!(result.unwrap(), Some(expected_record)); - } - - #[tokio::test] - async fn test_caching_store_delete() { - let mut cache = MockCache::new(); - let mut store = MockStore::new(); - let session_id = Id::default(); - - cache - .expect_delete() - .with(predicate::eq(session_id)) - .times(1) - .returning(|_| Ok(())); - store - .expect_delete() - .with(predicate::eq(session_id)) - .times(1) - .returning(|_| Ok(())); - - let caching_store = CachingSessionStore::new(cache, store); - let result = caching_store.delete(&session_id).await; - assert!(result.is_ok()); - } -}