From 7876f57da8b850fae4e11dc85f0b953410e07ab2 Mon Sep 17 00:00:00 2001 From: Eric Date: Mon, 25 Dec 2023 11:33:50 +0800 Subject: [PATCH] feat: support relay-style pagination, apply on uses & invitations (#1106) * feat: support relay-style pagination, apply on uses & invitations * resolve comment --- Cargo.lock | 16 +- crates/juniper-axum/src/lib.rs | 1 + crates/juniper-axum/src/relay/connection.rs | 155 +++++++++++ crates/juniper-axum/src/relay/edge.rs | 109 ++++++++ crates/juniper-axum/src/relay/mod.rs | 115 ++++++++ crates/juniper-axum/src/relay/node_type.rs | 22 ++ crates/juniper-axum/src/relay/page_info.rs | 10 + ee/tabby-webserver/graphql/schema.graphql | 63 ++++- ee/tabby-webserver/src/schema/auth.rs | 55 ++++ ee/tabby-webserver/src/schema/mod.rs | 96 ++++++- ee/tabby-webserver/src/service/auth.rs | 52 +++- .../src/service/db/invitations.rs | 26 ++ ee/tabby-webserver/src/service/db/mod.rs | 34 +++ ee/tabby-webserver/src/service/db/users.rs | 260 +++++++++++++++++- 14 files changed, 984 insertions(+), 30 deletions(-) create mode 100644 crates/juniper-axum/src/relay/connection.rs create mode 100644 crates/juniper-axum/src/relay/edge.rs create mode 100644 crates/juniper-axum/src/relay/mod.rs create mode 100644 crates/juniper-axum/src/relay/node_type.rs create mode 100644 crates/juniper-axum/src/relay/page_info.rs diff --git a/Cargo.lock b/Cargo.lock index ac879cdf0ec1..14459e3df9a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,7 +230,7 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" dependencies = [ "async-trait", "axum-core", - "base64 0.21.2", + "base64 0.21.5", "bitflags 1.3.2", "bytes", "futures-util", @@ -357,9 +357,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64ct" @@ -1761,7 +1761,7 @@ version = "9.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "155c4d7e39ad04c172c5e3a99c434ea3b4a7ba7960b38ecd562b270b097cce09" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "pem", "ring", "serde", @@ -2063,7 +2063,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "hyper", "indexmap 1.9.3", "ipnet", @@ -2535,7 +2535,7 @@ version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3163d2912b7c3b52d651a055f2c7eec9ba5cd22d26ef75b8dd3a59980b185923" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "serde", ] @@ -2966,7 +2966,7 @@ version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "bytes", "encoding_rs", "futures-core", @@ -3831,7 +3831,7 @@ dependencies = [ "aho-corasick", "arc-swap", "async-trait", - "base64 0.21.2", + "base64 0.21.5", "bitpacking", "byteorder", "census", diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs index 092f46f873d0..15ae6380f785 100644 --- a/crates/juniper-axum/src/lib.rs +++ b/crates/juniper-axum/src/lib.rs @@ -1,4 +1,5 @@ pub mod extract; +pub mod relay; pub mod response; use std::future; diff --git a/crates/juniper-axum/src/relay/connection.rs b/crates/juniper-axum/src/relay/connection.rs new file mode 100644 index 000000000000..53dbba4fb316 --- /dev/null +++ b/crates/juniper-axum/src/relay/connection.rs @@ -0,0 +1,155 @@ +use juniper::{ + marker::IsOutputType, meta::MetaType, Arguments, ExecutionResult, Executor, GraphQLType, + GraphQLValue, GraphQLValueAsync, Registry, ScalarValue, +}; + +use crate::relay::{edge::Edge, page_info::PageInfo, NodeType}; + +/// Connection type +/// +/// Connection is the result of a query for `relay::query` or `relay::query_async` +pub struct Connection { + /// All edges of the current page. + pub edges: Vec>, + pub page_info: PageInfo, +} + +impl Connection +where + Node: NodeType, +{ + /// Returns a relay relay with no elements. + pub fn empty() -> Self { + Self { + edges: Vec::new(), + page_info: PageInfo::default(), + } + } + + pub fn build_connection(nodes: Vec, first: Option, last: Option) -> Self { + let has_next_page = first.map_or(false, |first| nodes.len() > first); + let has_previous_page = last.map_or(false, |last| nodes.len() > last); + let len = nodes.len(); + + let edges: Vec<_> = if let Some(last) = last { + nodes + .into_iter() + .rev() + .take(last) + .rev() + .map(|node| { + let cursor = node.cursor(); + Edge::new(cursor.to_string(), node) + }) + .collect() + } else { + nodes + .into_iter() + .take(first.unwrap_or(len)) + .map(|node| { + let cursor = node.cursor(); + Edge::new(cursor.to_string(), node) + }) + .collect() + }; + + Connection { + page_info: PageInfo { + has_previous_page, + has_next_page, + start_cursor: edges.first().map(|edge| edge.cursor.clone()), + end_cursor: edges.last().map(|edge| edge.cursor.clone()), + }, + edges, + } + } +} + +impl GraphQLType for Connection +where + Node: NodeType + GraphQLType, + Node::Context: juniper::Context, + S: ScalarValue, +{ + fn name(_info: &Self::TypeInfo) -> Option<&str> { + Some(Node::connection_type_name()) + } + + fn meta<'r>(info: &Self::TypeInfo, registry: &mut Registry<'r, S>) -> MetaType<'r, S> + where + S: 'r, + { + let fields = [ + registry.field::<&Vec>>("edges", info), + registry.field::<&PageInfo>("pageInfo", &()), + ]; + registry + .build_object_type::(info, &fields) + .into_meta() + } +} + +impl GraphQLValue for Connection +where + Node: NodeType + GraphQLType, + Node::Context: juniper::Context, + S: ScalarValue, +{ + type Context = Node::Context; + type TypeInfo = >::TypeInfo; + + fn type_name<'i>(&self, info: &'i Self::TypeInfo) -> Option<&'i str> { + >::name(info) + } + + fn resolve_field( + &self, + info: &Self::TypeInfo, + field_name: &str, + _arguments: &Arguments, + executor: &Executor, + ) -> ExecutionResult { + match field_name { + "edges" => executor.resolve_with_ctx(info, &self.edges), + "pageInfo" => executor.resolve_with_ctx(&(), &self.page_info), + _ => panic!("Field {} not found on type ConnectionEdge", field_name), + } + } + + fn concrete_type_name(&self, _context: &Self::Context, info: &Self::TypeInfo) -> String { + self.type_name(info).unwrap_or("Connection").to_string() + } +} + +impl GraphQLValueAsync for Connection +where + Node: NodeType + GraphQLType + GraphQLValueAsync + Send + Sync, + Node::TypeInfo: Sync, + Node::Context: juniper::Context + Sync, + S: ScalarValue + Send + Sync, +{ + fn resolve_field_async<'a>( + &'a self, + info: &'a Self::TypeInfo, + field_name: &'a str, + _arguments: &'a Arguments, + executor: &'a Executor, + ) -> juniper::BoxFuture<'a, ExecutionResult> { + let f = async move { + match field_name { + "edges" => executor.resolve_with_ctx_async(info, &self.edges).await, + "pageInfo" => executor.resolve_with_ctx(&(), &self.page_info), + _ => panic!("Field {} not found on type ConnectionEdge", field_name), + } + }; + use ::juniper::futures::future; + future::FutureExt::boxed(f) + } +} + +impl IsOutputType for Connection +where + Node: GraphQLType, + S: ScalarValue, +{ +} diff --git a/crates/juniper-axum/src/relay/edge.rs b/crates/juniper-axum/src/relay/edge.rs new file mode 100644 index 000000000000..b8593ee5c546 --- /dev/null +++ b/crates/juniper-axum/src/relay/edge.rs @@ -0,0 +1,109 @@ +use juniper::{ + marker::IsOutputType, meta::MetaType, Arguments, BoxFuture, ExecutionResult, Executor, + GraphQLType, GraphQLValue, GraphQLValueAsync, Registry, ScalarValue, +}; + +use crate::relay::NodeType; + +/// An edge in a relay. +pub struct Edge { + pub cursor: String, + pub node: Node, +} + +impl Edge { + /// Create a new edge. + #[inline] + pub fn new(cursor: String, node: Node) -> Self { + Self { cursor, node } + } +} + +impl GraphQLType for Edge +where + Node: NodeType + GraphQLType, + Node::Context: juniper::Context, + S: ScalarValue, +{ + fn name(_info: &Self::TypeInfo) -> Option<&str> { + Some(Node::edge_type_name()) + } + + fn meta<'r>(info: &Self::TypeInfo, registry: &mut Registry<'r, S>) -> MetaType<'r, S> + where + S: 'r, + { + let fields = [ + registry.field::<&Node>("node", info), + registry.field::<&String>("cursor", &()), + ]; + registry + .build_object_type::(info, &fields) + .into_meta() + } +} + +impl GraphQLValue for Edge +where + Node: NodeType + GraphQLType, + Node::Context: juniper::Context, + S: ScalarValue, +{ + type Context = Node::Context; + type TypeInfo = >::TypeInfo; + + fn type_name<'i>(&self, info: &'i Self::TypeInfo) -> Option<&'i str> { + >::name(info) + } + + fn resolve_field( + &self, + info: &Self::TypeInfo, + field_name: &str, + _arguments: &Arguments, + executor: &Executor, + ) -> ExecutionResult { + match field_name { + "node" => executor.resolve_with_ctx(info, &self.node), + "cursor" => executor.resolve_with_ctx(&(), &self.cursor), + _ => panic!("Field {} not found on type ConnectionEdge", field_name), + } + } + + fn concrete_type_name(&self, _context: &Self::Context, info: &Self::TypeInfo) -> String { + self.type_name(info).unwrap_or("ConnectionEdge").to_string() + } +} + +impl GraphQLValueAsync for Edge +where + Node: NodeType + GraphQLType + GraphQLValueAsync + Send + Sync, + Node::TypeInfo: Sync, + Node::Context: juniper::Context + Sync, + S: ScalarValue + Send + Sync, +{ + fn resolve_field_async<'a>( + &'a self, + info: &'a Self::TypeInfo, + field_name: &'a str, + _arguments: &'a Arguments, + executor: &'a Executor, + ) -> BoxFuture<'a, ExecutionResult> { + let f = async move { + match field_name { + "node" => executor.resolve_with_ctx_async(info, &self.node).await, + "cursor" => executor.resolve_with_ctx(&(), &self.cursor), + _ => panic!("Field {} not found on type RelayConnectionEdge", field_name), + } + }; + use ::juniper::futures::future; + future::FutureExt::boxed(f) + } +} + +impl IsOutputType for Edge +where + Node: GraphQLType, + S: ScalarValue, +{ +} diff --git a/crates/juniper-axum/src/relay/mod.rs b/crates/juniper-axum/src/relay/mod.rs new file mode 100644 index 000000000000..7aaf02ab730b --- /dev/null +++ b/crates/juniper-axum/src/relay/mod.rs @@ -0,0 +1,115 @@ +use std::future::Future; + +use juniper::FieldResult; + +mod connection; +mod edge; +mod node_type; +mod page_info; + +pub use connection::Connection; +pub use edge::Edge; +pub use node_type::NodeType; +pub use page_info::PageInfo; + +pub fn query( + after: Option, + before: Option, + first: Option, + last: Option, + f: F, +) -> FieldResult> +where + Node: NodeType + Sync, + F: FnOnce( + Option, + Option, + Option, + Option, + ) -> FieldResult>, +{ + if first.is_some() && last.is_some() { + return Err("The \"first\" and \"last\" parameters cannot exist at the same time".into()); + } + + let first = match first { + Some(first) if first < 0 => { + return Err("The \"first\" parameter must be a non-negative number".into()); + } + Some(first) => Some(first as usize), + None => None, + }; + + let last = match last { + Some(last) if last < 0 => { + return Err("The \"last\" parameter must be a non-negative number".into()); + } + Some(last) => Some(last as usize), + None => None, + }; + + match (first, last) { + (None, None) => { + let nodes = f(after, before, None, None)?; + Ok(Connection::build_connection(nodes, None, None)) + } + (Some(first), None) => { + let nodes = f(after, before, Some(first + 1), None)?; + Ok(Connection::build_connection(nodes, Some(first), None)) + } + (None, Some(last)) => { + let nodes = f(after, before, None, Some(last + 1))?; + Ok(Connection::build_connection(nodes, None, Some(last))) + } + _ => Err("The \"first\" and \"last\" parameters cannot exist at the same time".into()), + } +} + +pub async fn query_async( + after: Option, + before: Option, + first: Option, + last: Option, + f: F, +) -> FieldResult> +where + Node: NodeType + Sync, + F: FnOnce(Option, Option, Option, Option) -> R, + R: Future>>, +{ + if first.is_some() && last.is_some() { + return Err("The \"first\" and \"last\" parameters cannot exist at the same time".into()); + } + + let first = match first { + Some(first) if first < 0 => { + return Err("The \"first\" parameter must be a non-negative number".into()); + } + Some(first) => Some(first as usize), + None => None, + }; + + let last = match last { + Some(last) if last < 0 => { + return Err("The \"last\" parameter must be a non-negative number".into()); + } + Some(last) => Some(last as usize), + None => None, + }; + + match (first, last) { + (None, None) => { + let nodes = f(after, before, None, None).await?; + Ok(Connection::build_connection(nodes, None, None)) + } + (Some(first), None) => { + let nodes = f(after, before, Some(first + 1), None).await?; + Ok(Connection::build_connection(nodes, Some(first), None)) + } + (None, Some(last)) => { + let nodes = f(after, before, None, Some(last + 1)).await?; + Ok(Connection::build_connection(nodes, None, Some(last))) + } + _ => Err("The \"first\" and \"last\" parameters cannot exist at the same time".into()), + } +} diff --git a/crates/juniper-axum/src/relay/node_type.rs b/crates/juniper-axum/src/relay/node_type.rs new file mode 100644 index 000000000000..6d767652db95 --- /dev/null +++ b/crates/juniper-axum/src/relay/node_type.rs @@ -0,0 +1,22 @@ +use std::str::FromStr; + +pub trait NodeType { + /// The [cursor][spec] type that is used for pagination. A cursor + /// should uniquely identify a given node. + /// + /// [spec]: https://relay.dev/graphql/connections.htm#sec-Cursor + type Cursor: ToString + FromStr + Clone; + + /// Returns the cursor associated with this node. + fn cursor(&self) -> Self::Cursor; + + /// Returns the type name connections + /// over these nodes should have in the + /// API. E.g. `"FooConnection"`. + fn connection_type_name() -> &'static str; + + /// Returns the type name edges containing + /// these nodes should have in the API. + /// E.g. `"FooConnectionEdge"`. + fn edge_type_name() -> &'static str; +} diff --git a/crates/juniper-axum/src/relay/page_info.rs b/crates/juniper-axum/src/relay/page_info.rs new file mode 100644 index 000000000000..643d783c2f3b --- /dev/null +++ b/crates/juniper-axum/src/relay/page_info.rs @@ -0,0 +1,10 @@ +use juniper::GraphQLObject; + +#[derive(Default, GraphQLObject)] +#[graphql(name = "PageInfo")] +pub struct PageInfo { + pub has_previous_page: bool, + pub has_next_page: bool, + pub start_cursor: Option, + pub end_cursor: Option, +} diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 5c41e3594f00..64ecc8c1cd7d 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -1,8 +1,3 @@ -type RegisterResponse { - accessToken: String! - refreshToken: String! -} - enum WorkerKind { COMPLETION CHAT @@ -22,10 +17,6 @@ type Mutation { "DateTime" scalar DateTimeUtc -type VerifyTokenResponse { - claims: JWTPayload! -} - type JWTPayload { "Expiration time (as UTC timestamp)" exp: Float! @@ -44,6 +35,40 @@ type Query { invitations: [Invitation!]! me: User! users: [User!]! + usersNext(after: String, before: String, first: Int, last: Int): UserConnection! + invitationsNext(after: String, before: String, first: Int, last: Int): InvitationConnection! +} + +type UserEdge { + node: User! + cursor: String! +} + +type RefreshTokenResponse { + accessToken: String! + refreshToken: String! + refreshExpiresAt: DateTimeUtc! +} + +type RegisterResponse { + accessToken: String! + refreshToken: String! +} + +type InvitationNext { + id: ID! + email: String! + code: String! + createdAt: String! +} + +type UserConnection { + edges: [UserEdge!]! + pageInfo: PageInfo! +} + +type VerifyTokenResponse { + claims: JWTPayload! } type Invitation { @@ -54,6 +79,7 @@ type Invitation { } type User { + id: ID! email: String! isAdmin: Boolean! authToken: String! @@ -71,15 +97,26 @@ type Worker { cudaDevices: [String!]! } +type InvitationEdge { + node: InvitationNext! + cursor: String! +} + type TokenAuthResponse { accessToken: String! refreshToken: String! } -type RefreshTokenResponse { - accessToken: String! - refreshToken: String! - refreshExpiresAt: DateTimeUtc! +type PageInfo { + hasPreviousPage: Boolean! + hasNextPage: Boolean! + startCursor: String + endCursor: String +} + +type InvitationConnection { + edges: [InvitationEdge!]! + pageInfo: PageInfo! } schema { diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index a3e545649895..9709700345cc 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -5,6 +5,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use jsonwebtoken as jwt; use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue}; +use juniper_axum::relay; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -13,6 +14,7 @@ use uuid::Uuid; use validator::ValidationErrors; use super::{from_validation_errors, User}; +use crate::schema::Context; lazy_static! { static ref JWT_TOKEN_SECRET: String = jwt_token_secret(); @@ -244,6 +246,43 @@ pub struct Invitation { pub created_at: String, } +#[derive(Debug, Serialize, Deserialize, GraphQLObject)] +#[graphql(context = Context)] +pub struct InvitationNext { + pub id: juniper::ID, + pub email: String, + pub code: String, + + pub created_at: String, +} + +impl relay::NodeType for InvitationNext { + type Cursor = String; + + fn cursor(&self) -> Self::Cursor { + self.id.to_string() + } + + fn connection_type_name() -> &'static str { + "InvitationConnection" + } + + fn edge_type_name() -> &'static str { + "InvitationEdge" + } +} + +impl From for InvitationNext { + fn from(val: Invitation) -> Self { + Self { + id: juniper::ID::new(val.id.to_string()), + email: val.email, + code: val.code, + created_at: val.created_at, + } + } +} + #[async_trait] pub trait AuthenticationService: Send + Sync { async fn register( @@ -275,6 +314,22 @@ pub trait AuthenticationService: Send + Sync { async fn reset_user_auth_token(&self, email: &str) -> Result<()>; async fn list_users(&self) -> Result>; + + async fn list_users_in_page( + &self, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result>; + + async fn list_invitations_in_page( + &self, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result>; } #[cfg(test)] diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index f7f3a298952f..a6fdd0b1ffd0 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -6,10 +6,10 @@ use std::sync::Arc; use auth::AuthenticationService; use chrono::{DateTime, Utc}; use juniper::{ - graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLObject, IntoFieldError, - Object, RootNode, ScalarValue, Value, + graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, GraphQLObject, + IntoFieldError, Object, RootNode, ScalarValue, Value, }; -use juniper_axum::FromAuth; +use juniper_axum::{relay, FromAuth}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use validator::ValidationErrors; @@ -19,8 +19,8 @@ use self::{ }; use crate::schema::{ auth::{ - RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, - VerifyTokenResponse, + InvitationNext, RefreshTokenError, RefreshTokenResponse, RegisterResponse, + TokenAuthResponse, VerifyTokenResponse, }, worker::Worker, }; @@ -128,16 +128,102 @@ impl Query { } Err(CoreError::Unauthorized("Only admin is able to query users")) } + + async fn usersNext( + ctx: &Context, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> FieldResult> { + if let Some(claims) = &ctx.claims { + if claims.is_admin { + return relay::query_async( + after, + before, + first, + last, + |after, before, first, last| async move { + match ctx + .locator + .auth() + .list_users_in_page(after, before, first, last) + .await + { + Ok(users) => Ok(users), + Err(err) => Err(FieldError::from(err)), + } + }, + ) + .await; + } + } + Err(FieldError::from(CoreError::Unauthorized( + "Only admin is able to query users", + ))) + } + + async fn invitationsNext( + ctx: &Context, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> FieldResult> { + if let Some(claims) = &ctx.claims { + if claims.is_admin { + return relay::query_async( + after, + before, + first, + last, + |after, before, first, last| async move { + match ctx + .locator + .auth() + .list_invitations_in_page(after, before, first, last) + .await + { + Ok(invitations) => Ok(invitations), + Err(err) => Err(FieldError::from(err)), + } + }, + ) + .await; + } + } + Err(FieldError::from(CoreError::Unauthorized( + "Only admin is able to query users", + ))) + } } #[derive(Debug, GraphQLObject)] +#[graphql(context = Context)] pub struct User { + pub id: juniper::ID, pub email: String, pub is_admin: bool, pub auth_token: String, pub created_at: DateTime, } +impl relay::NodeType for User { + type Cursor = String; + + fn cursor(&self) -> Self::Cursor { + self.id.to_string() + } + + fn connection_type_name() -> &'static str { + "UserConnection" + } + + fn edge_type_name() -> &'static str { + "UserEdge" + } +} + #[derive(Default)] pub struct Mutation; diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 9b2d9d57e00d..6a12a94ffbb9 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -13,8 +13,8 @@ use super::db::DbConn; use crate::schema::{ auth::{ generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Invitation, - JWTPayload, RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, - TokenAuthError, TokenAuthResponse, VerifyTokenResponse, + InvitationNext, JWTPayload, RefreshTokenError, RefreshTokenResponse, RegisterError, + RegisterResponse, TokenAuthError, TokenAuthResponse, VerifyTokenResponse, }, User, }; @@ -301,6 +301,54 @@ impl AuthenticationService for DbConn { let users = self.list_users().await?; Ok(users.into_iter().map(|x| x.into()).collect()) } + + async fn list_users_in_page( + &self, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result> { + let users = match (first, last) { + (Some(first), None) => { + let after = after.map(|x| x.parse::()).transpose()?; + self.list_users_with_filter(Some(first), after, false) + .await? + } + (None, Some(last)) => { + let before = before.map(|x| x.parse::()).transpose()?; + self.list_users_with_filter(Some(last), before, true) + .await? + } + _ => self.list_users().await?, + }; + + Ok(users.into_iter().map(|x| x.into()).collect()) + } + + async fn list_invitations_in_page( + &self, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result> { + let invitations = match (first, last) { + (Some(first), None) => { + let after = after.map(|x| x.parse::()).transpose()?; + self.list_invitations_with_filter(Some(first), after, false) + .await? + } + (None, Some(last)) => { + let before = before.map(|x| x.parse::()).transpose()?; + self.list_invitations_with_filter(Some(last), before, true) + .await? + } + _ => self.list_invitations().await?, + }; + + Ok(invitations.into_iter().map(|x| x.into()).collect()) + } } fn password_hash(raw: &str) -> password_hash::Result { diff --git a/ee/tabby-webserver/src/service/db/invitations.rs b/ee/tabby-webserver/src/service/db/invitations.rs index 10d2f1a7aafd..32303c5e84c5 100644 --- a/ee/tabby-webserver/src/service/db/invitations.rs +++ b/ee/tabby-webserver/src/service/db/invitations.rs @@ -32,6 +32,32 @@ impl DbConn { Ok(invitations) } + pub async fn list_invitations_with_filter( + &self, + limit: Option, + skip_id: Option, + backwards: bool, + ) -> Result> { + let query = Self::make_pagination_query( + "invitations", + &["id", "email", "code", "created_at"], + limit, + skip_id, + backwards, + ); + + let invitations = self + .conn + .call(move |c| { + let mut stmt = c.prepare(&query)?; + let invit_iter = stmt.query_map([], Invitation::from_row)?; + Ok(invit_iter.filter_map(|x| x.ok()).collect::>()) + }) + .await?; + + Ok(invitations) + } + pub async fn get_invitation_by_code(&self, code: &str) -> Result> { let code = code.to_owned(); let token = self diff --git a/ee/tabby-webserver/src/service/db/mod.rs b/ee/tabby-webserver/src/service/db/mod.rs index 4888cb89d63f..03b0407bf20c 100644 --- a/ee/tabby-webserver/src/service/db/mod.rs +++ b/ee/tabby-webserver/src/service/db/mod.rs @@ -54,6 +54,40 @@ impl DbConn { let res = Self { conn }; Ok(res) } + + fn make_pagination_query( + table_name: &str, + field_names: &[&str], + limit: Option, + skip_id: Option, + backwards: bool, + ) -> String { + let mut source = String::new(); + let mut clause = String::new(); + if backwards { + source += &format!("SELECT * FROM {}", table_name); + if let Some(skip_id) = skip_id { + source += &format!(" WHERE id < {}", skip_id); + } + source += " ORDER BY id DESC"; + if let Some(limit) = limit { + source += &format!(" LIMIT {}", limit); + } + clause += " ORDER BY id ASC"; + } else { + source += table_name; + if let Some(skip_id) = skip_id { + clause += &format!(" WHERE id > {}", skip_id); + } + clause += " ORDER BY id ASC"; + if let Some(limit) = limit { + clause += &format!(" LIMIT {}", limit); + } + } + let fields = field_names.join(", "); + + format!(r#"SELECT {} FROM ({}) {}"#, fields, source, clause) + } } /// db read/write operations for `registration_token` table diff --git a/ee/tabby-webserver/src/service/db/users.rs b/ee/tabby-webserver/src/service/db/users.rs index 53c977f51ec1..b3833065fb39 100644 --- a/ee/tabby-webserver/src/service/db/users.rs +++ b/ee/tabby-webserver/src/service/db/users.rs @@ -1,5 +1,3 @@ -// db read/write operations for `users` table - use anyhow::Result; use chrono::{DateTime, Utc}; use rusqlite::{params, OptionalExtension, Row}; @@ -45,6 +43,7 @@ impl User { impl From for schema::User { fn from(val: User) -> Self { schema::User { + id: juniper::ID::new(val.id.to_string()), email: val.email, is_admin: val.is_admin, auth_token: val.auth_token, @@ -53,6 +52,7 @@ impl From for schema::User { } } +/// db read/write operations for `users` table impl DbConn { pub async fn create_user( &self, @@ -163,6 +163,40 @@ impl DbConn { Ok(users) } + pub async fn list_users_with_filter( + &self, + limit: Option, + skip_id: Option, + backwards: bool, + ) -> Result> { + let query = Self::make_pagination_query( + "users", + &[ + "id", + "email", + "password_encrypted", + "is_admin", + "created_at", + "updated_at", + "auth_token", + ], + limit, + skip_id, + backwards, + ); + + let users = self + .conn + .call(move |c| { + let mut stmt = c.prepare(&query)?; + let user_iter = stmt.query_map([], User::from_row)?; + Ok(user_iter.filter_map(|x| x.ok()).collect::>()) + }) + .await?; + + Ok(users) + } + pub async fn verify_auth_token(&self, token: &str) -> bool { let token = token.to_owned(); let id: Result, _> = self @@ -242,4 +276,226 @@ mod tests { assert_eq!(user.email, new_user.email); assert_ne!(user.auth_token, new_user.auth_token); } + + #[tokio::test] + async fn test_list_users_with_filter() { + let conn = DbConn::new_in_memory().await.unwrap(); + + let empty: Vec = vec![]; + let to_ids = |users: Vec| users.into_iter().map(|u| u.id).collect::>(); + + // empty + // forwards + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(None, None, false) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(2), None, false) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(None, Some(1), false) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(2), Some(1), false) + .await + .unwrap() + ) + ); + // backwards + assert_eq!( + empty, + to_ids(conn.list_users_with_filter(None, None, true).await.unwrap()) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(2), None, true) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(None, Some(1), true) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(1), Some(1), true) + .await + .unwrap() + ) + ); + + let id1 = conn + .create_user("use1@example.com".into(), "123456".into(), false) + .await + .unwrap(); + + // one user + // forwards + assert_eq!( + vec![id1], + to_ids( + conn.list_users_with_filter(None, None, false) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id1], + to_ids( + conn.list_users_with_filter(Some(2), None, false) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(None, Some(1), false) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(2), Some(1), false) + .await + .unwrap() + ) + ); + // backwards + assert_eq!( + vec![id1], + to_ids(conn.list_users_with_filter(None, None, true).await.unwrap()) + ); + assert_eq!( + vec![id1], + to_ids( + conn.list_users_with_filter(Some(2), None, true) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(None, Some(1), true) + .await + .unwrap() + ) + ); + assert_eq!( + empty, + to_ids( + conn.list_users_with_filter(Some(1), Some(1), true) + .await + .unwrap() + ) + ); + + let id2 = conn + .create_user("use2@example.com".into(), "123456".into(), false) + .await + .unwrap(); + let id3 = conn + .create_user("use3@example.com".into(), "123456".into(), false) + .await + .unwrap(); + let id4 = conn + .create_user("use4@example.com".into(), "123456".into(), false) + .await + .unwrap(); + let id5 = conn + .create_user("use5@example.com".into(), "123456".into(), false) + .await + .unwrap(); + + // multiple users + // forwards + assert_eq!( + vec![id1, id2, id3, id4, id5], + to_ids( + conn.list_users_with_filter(None, None, false) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id1, id2], + to_ids( + conn.list_users_with_filter(Some(2), None, false) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id3, id4, id5], + to_ids( + conn.list_users_with_filter(None, Some(2), false) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id3, id4], + to_ids( + conn.list_users_with_filter(Some(2), Some(2), false) + .await + .unwrap() + ) + ); + // backwards + assert_eq!( + vec![id1, id2, id3, id4, id5], + to_ids(conn.list_users_with_filter(None, None, true).await.unwrap()) + ); + assert_eq!( + vec![id4, id5], + to_ids( + conn.list_users_with_filter(Some(2), None, true) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id1, id2, id3], + to_ids( + conn.list_users_with_filter(None, Some(4), true) + .await + .unwrap() + ) + ); + assert_eq!( + vec![id2, id3], + to_ids( + conn.list_users_with_filter(Some(2), Some(4), true) + .await + .unwrap() + ) + ); + } }