Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Domain verification chanllenge refactor #154

Merged
merged 9 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions database/migrations/0014_domain_verifications.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE TABLE domain_verifications(
domain_name TEXT PRIMARY KEY,
app_id TEXT NOT NULL,
code TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
finished_at TIMESTAMPTZ
);

CREATE INDEX domain_verifications_app_id_idx ON domain_verifications(app_id);
3 changes: 3 additions & 0 deletions database/src/tables/domain_verifications/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod select;
pub mod table_struct;
pub mod update;
39 changes: 39 additions & 0 deletions database/src/tables/domain_verifications/select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::{
db::Db,
structs::db_error::DbError,
tables::domain_verifications::table_struct::{
DomainVerification, DOMAIN_VERIFICATIONS_TABLE_NAME,
},
};
use sqlx::query_as;

impl Db {
pub async fn get_domain_verifications_by_app_id(
&self,
app_id: &String,
) -> Result<Vec<DomainVerification>, DbError> {
let query = format!("SELECT * FROM {DOMAIN_VERIFICATIONS_TABLE_NAME} WHERE app_id = $1 ORDER BY created_at DESC");
let typed_query = query_as::<_, DomainVerification>(&query);

return typed_query
.bind(&app_id)
.fetch_all(&self.connection_pool)
.await
.map_err(|e| e.into());
}

pub async fn get_domain_verification_by_domain_name(
&self,
domain_name: &String,
) -> Result<Option<DomainVerification>, DbError> {
let query =
format!("SELECT * FROM {DOMAIN_VERIFICATIONS_TABLE_NAME} WHERE domain_name = $1");
let typed_query = query_as::<_, DomainVerification>(&query);

return typed_query
.bind(&domain_name)
.fetch_optional(&self.connection_pool)
.await
.map_err(|e| e.into());
}
}
29 changes: 29 additions & 0 deletions database/src/tables/domain_verifications/table_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use sqlx::{
postgres::PgRow,
types::chrono::{DateTime, Utc},
FromRow, Row,
};

pub const DOMAIN_VERIFICATIONS_TABLE_NAME: &str = "domain_verifications";
pub const DOMAIN_VERIFICATIONS_KEYS: &str = "domain_name, app_id, code, created_at, finished_at";

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DomainVerification {
pub domain_name: String,
pub app_id: String,
pub code: String,
pub created_at: DateTime<Utc>,
pub finished_at: Option<DateTime<Utc>>,
}

impl FromRow<'_, PgRow> for DomainVerification {
fn from_row(row: &sqlx::postgres::PgRow) -> std::result::Result<Self, sqlx::Error> {
Ok(DomainVerification {
domain_name: row.get("domain_name"),
app_id: row.get("app_id"),
code: row.get("code"),
created_at: row.get("created_at"),
finished_at: row.get("finished_at"),
})
}
}
52 changes: 52 additions & 0 deletions database/src/tables/domain_verifications/update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::table_struct::{DOMAIN_VERIFICATIONS_KEYS, DOMAIN_VERIFICATIONS_TABLE_NAME};
use crate::db::Db;
use crate::structs::db_error::DbError;
use crate::tables::utils::get_current_datetime;
use sqlx::query;

impl Db {
pub async fn create_new_domain_verification_entry(
&self,
domain_name: &String,
app_id: &String,
code: &String,
) -> Result<(), DbError> {
let query_body = format!(
"INSERT INTO {DOMAIN_VERIFICATIONS_TABLE_NAME} ({DOMAIN_VERIFICATIONS_KEYS}) VALUES ($1, $2, $3, $4, NULL)"
);

let query_result = query(&query_body)
.bind(&domain_name)
.bind(&app_id)
.bind(&code)
.bind(&get_current_datetime())
.execute(&self.connection_pool)
.await;

match query_result {
Ok(_) => Ok(()),
Err(e) => Err(e).map_err(|e| e.into()),
}
}

pub async fn finish_domain_verification(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
domain_name: &String,
) -> Result<(), DbError> {
let query_body = format!(
"UPDATE {DOMAIN_VERIFICATIONS_TABLE_NAME} SET finished_at = $1 WHERE domain_name = $2"
);

let query_result = query(&query_body)
.bind(&get_current_datetime())
.bind(&domain_name)
.execute(&mut **tx)
.await;

match query_result {
Ok(_) => Ok(()),
Err(e) => Err(e).map_err(|e| e.into()),
}
}
}
1 change: 1 addition & 0 deletions database/src/tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod ip_addresses;
pub mod public_keys;
pub mod registered_app;
// pub mod requests;
pub mod domain_verifications;
pub mod session_public_keys;
pub mod sessions;
pub mod team;
Expand Down
3 changes: 2 additions & 1 deletion database/src/tables/registered_app/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl Db {

pub async fn add_new_whitelisted_domain(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
app_id: &str,
domain: &str,
) -> Result<(), DbError> {
Expand All @@ -62,7 +63,7 @@ impl Db {
let query_result = query(&query_body)
.bind(domain)
.bind(app_id)
.execute(&self.connection_pool)
.execute(&mut **tx)
.await;

match query_result {
Expand Down
2 changes: 1 addition & 1 deletion server/bindings/CloudApiErrors.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.

export type CloudApiErrors = "TeamDoesNotExist" | "UserDoesNotExist" | "CloudFeatureDisabled" | "InsufficientPermissions" | "TeamHasNoRegisteredApps" | "DatabaseError" | "MaximumUsersPerTeamReached" | "UserAlreadyBelongsToTheTeam" | "IncorrectPassword" | "AccessTokenFailure" | "RefreshTokenFailure" | "AppAlreadyExists" | "MaximumAppsPerTeamReached" | "TeamAlreadyExists" | "PersonalTeamAlreadyExists" | "EmailAlreadyExists" | "InternalServerError" | "UserDoesNotBelongsToTheTeam" | "InvalidName" | "UnauthorizedOriginError" | "AppDoesNotExist" | "UserAlreadyInvitedToTheTeam" | "MaximumInvitesPerTeamReached" | "InviteNotFound" | "ActionForbiddenForPersonalTeam" | "InviteDoesNotExist" | "InvalidPaginationCursor" | "InvalidVerificationCode" | "InvalidDomainName" | "DomainAlreadyVerified" | "DomainVerificationFailure" | "DomainNotFound";
export type CloudApiErrors = "TeamDoesNotExist" | "UserDoesNotExist" | "CloudFeatureDisabled" | "InsufficientPermissions" | "TeamHasNoRegisteredApps" | "DatabaseError" | "MaximumUsersPerTeamReached" | "UserAlreadyBelongsToTheTeam" | "IncorrectPassword" | "AccessTokenFailure" | "RefreshTokenFailure" | "AppAlreadyExists" | "MaximumAppsPerTeamReached" | "TeamAlreadyExists" | "PersonalTeamAlreadyExists" | "EmailAlreadyExists" | "InternalServerError" | "UserDoesNotBelongsToTheTeam" | "InvalidName" | "UnauthorizedOriginError" | "AppDoesNotExist" | "UserAlreadyInvitedToTheTeam" | "MaximumInvitesPerTeamReached" | "InviteNotFound" | "ActionForbiddenForPersonalTeam" | "InviteDoesNotExist" | "InvalidPaginationCursor" | "InvalidVerificationCode" | "InvalidDomainName" | "DomainAlreadyVerified" | "DomainVerificationFailure" | "DomainNotFound" | "DomainVerificationNotStarted";
71 changes: 54 additions & 17 deletions server/src/http/cloud/domains/verify_domain_finish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use crate::{
env::is_env_production,
http::cloud::utils::{custom_validate_domain_name, custom_validate_uuid},
middlewares::auth_middleware::UserId,
structs::{
cloud::api_cloud_errors::CloudApiErrors,
session_cache::{ApiSessionsCache, SessionCache, SessionsCacheKey},
},
structs::cloud::api_cloud_errors::CloudApiErrors,
};
use anyhow::bail;
use axum::{extract::State, http::StatusCode, Extension, Json};
Expand All @@ -33,7 +30,6 @@ pub struct HttpVerifyDomainFinishResponse {}

pub async fn verify_domain_finish(
State(db): State<Arc<Db>>,
State(sessions_cache): State<Arc<ApiSessionsCache>>,
State(dns_resolver): State<Arc<DnsResolver>>,
Extension(user_id): Extension<UserId>,
Json(request): Json<HttpVerifyDomainFinishRequest>,
Expand Down Expand Up @@ -105,26 +101,36 @@ pub async fn verify_domain_finish(
));
}

// Get session data
let sessions_key = SessionsCacheKey::DomainVerification(domain_name.clone()).to_string();
let session_data = match sessions_cache.get(&sessions_key) {
Some(SessionCache::VerifyDomain(session)) => session,
_ => {
// Get challenge data
let domain_verification_challenge = match db
.get_domain_verification_by_domain_name(&domain_name)
.await
{
Ok(Some(challenge)) => challenge,
Ok(None) => {
return Err((
StatusCode::BAD_REQUEST,
CloudApiErrors::DomainVerificationNotStarted.to_string(),
))
}
Err(err) => {
error!("Failed to get domain verification challenge: {:?}", err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::InternalServerError.to_string(),
CloudApiErrors::DatabaseError.to_string(),
));
}
};

// Remove leftover session data
sessions_cache.remove(&sessions_key);

// Validate the code
// Attempt to resolve the TXT records for the given domain, only on PROD
if is_env_production() {
if let Err(err) =
check_verification_code(&dns_resolver, &domain_name, &session_data.code).await
if let Err(err) = check_verification_code(
&dns_resolver,
&domain_name,
&domain_verification_challenge.code,
)
.await
{
error!("Failed to verify domain: {:?}, err: {:?}", domain_name, err);
return Err((
Expand All @@ -135,17 +141,48 @@ pub async fn verify_domain_finish(
}

// Add domain to whitelist
let mut tx = db.connection_pool.begin().await.unwrap();

if let Err(err) = db
.add_new_whitelisted_domain(&request.app_id, &domain_name)
.add_new_whitelisted_domain(&mut tx, &request.app_id, &domain_name)
.await
{
let _ = tx
.rollback()
.await
.map_err(|err| error!("Failed to rollback transaction: {:?}", err));

error!("Failed to add domain to whitelist: {:?}", err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::DatabaseError.to_string(),
));
}

// Update domain verification entry
if let Err(err) = db.finish_domain_verification(&mut tx, &domain_name).await {
let _ = tx
.rollback()
.await
.map_err(|err| error!("Failed to rollback transaction: {:?}", err));

error!("Failed to finish domain verification: {:?}", err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::DatabaseError.to_string(),
));
}

// Commit transaction
if let Err(err) = tx.commit().await {
error!("Failed to commit transaction: {:?}", err);

return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::DatabaseError.to_string(),
));
}

Ok(Json(HttpVerifyDomainFinishResponse {}))
}

Expand Down
54 changes: 32 additions & 22 deletions server/src/http/cloud/domains/verify_domain_start.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use crate::{
http::cloud::utils::{custom_validate_domain_name, custom_validate_uuid},
middlewares::auth_middleware::UserId,
structs::{
cloud::api_cloud_errors::CloudApiErrors,
session_cache::{ApiSessionsCache, DomainVerification, SessionCache, SessionsCacheKey},
},
utils::get_timestamp_in_milliseconds,
structs::cloud::api_cloud_errors::CloudApiErrors,
};
use axum::{extract::State, http::StatusCode, Extension, Json};
use database::{db::Db, structs::privilege_level::PrivilegeLevel};
Expand Down Expand Up @@ -33,7 +29,6 @@ pub struct HttpVerifyDomainStartResponse {

pub async fn verify_domain_start(
State(db): State<Arc<Db>>,
State(sessions_cache): State<Arc<ApiSessionsCache>>,
Extension(user_id): Extension<UserId>,
Json(request): Json<HttpVerifyDomainStartRequest>,
) -> Result<Json<HttpVerifyDomainStartResponse>, (StatusCode, String)> {
Expand Down Expand Up @@ -104,24 +99,39 @@ pub async fn verify_domain_start(
));
}

// Generate verification code
let verification_code =
format!("TXT Nc verification code {}", uuid7::uuid7().to_string()).to_string();
// Check if challenge already exists
let verification_code = match db
.get_domain_verification_by_domain_name(&domain_name)
.await
{
Ok(Some(challenge)) => challenge.code,
Ok(None) => {
// Challenge does not exist, generate new code
let code =
format!("TXT NCC verification code {}", uuid7::uuid7().to_string()).to_string();

// Save to cache
let sessions_key = SessionsCacheKey::DomainVerification(domain_name.clone()).to_string();
// Remove leftover session data
sessions_cache.remove(&sessions_key);
// Save challenge to the database
if let Err(err) = db
.create_new_domain_verification_entry(&domain_name, &request.app_id, &code)
.await
{
error!("Failed to save challenge to the database: {:?}", err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::DatabaseError.to_string(),
));
}

sessions_cache.set(
sessions_key,
SessionCache::VerifyDomain(DomainVerification {
domain_name: domain_name.clone(),
code: verification_code.clone(),
created_at: get_timestamp_in_milliseconds(),
}),
None,
);
code
}
Err(err) => {
error!("Failed to check if challenge exists: {:?}", err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
CloudApiErrors::DatabaseError.to_string(),
));
}
};

Ok(Json(HttpVerifyDomainStartResponse {
code: verification_code,
Expand Down
1 change: 1 addition & 0 deletions server/src/structs/cloud/api_cloud_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ pub enum CloudApiErrors {
DomainAlreadyVerified,
DomainVerificationFailure,
DomainNotFound,
DomainVerificationNotStarted,
}
Loading
Loading