Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build: add tls for mongo in deployment #754

Merged
merged 14 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 19 additions & 26 deletions Adaptors/MongoDB/src/ServiceCollectionExt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System;
using System.Security.Cryptography.X509Certificates;
using System.Security.Authentication;

using ArmoniK.Api.Common.Utils;
using ArmoniK.Core.Adapters.MongoDB.Common;
Expand All @@ -40,6 +40,8 @@
using MongoDB.Driver.Core.Configuration;
using MongoDB.Driver.Core.Extensions.DiagnosticSources;

using static ArmoniK.Core.Utils.CertificateValidator;

namespace ArmoniK.Core.Adapters.MongoDB;

public static class ServiceCollectionExt
Expand Down Expand Up @@ -97,7 +99,6 @@ public static IServiceCollection AddMongoClient(this IServiceCollection services
services.AddOption(configuration,
Options.MongoDB.SettingSection,
out mongoOptions);

using var _ = logger.BeginNamedScope("MongoDB configuration",
("host", mongoOptions.Host),
("port", mongoOptions.Port));
Expand Down Expand Up @@ -132,30 +133,6 @@ public static IServiceCollection AddMongoClient(this IServiceCollection services
logger.LogTrace("No credentials provided");
}

if (!string.IsNullOrEmpty(mongoOptions.CAFile))
{
var localTrustStore = new X509Store(StoreName.Root);
var certificateCollection = new X509Certificate2Collection();
try
{
certificateCollection.ImportFromPemFile(mongoOptions.CAFile);
localTrustStore.Open(OpenFlags.ReadWrite);
localTrustStore.AddRange(certificateCollection);
logger.LogTrace("Imported mongodb certificate from file {path}",
mongoOptions.CAFile);
}
catch (Exception ex)
{
logger.LogError("Root certificate import failed: {error}",
ex.Message);
throw;
}
finally
{
localTrustStore.Close();
}
}

string connectionString;
if (string.IsNullOrEmpty(mongoOptions.User) || string.IsNullOrEmpty(mongoOptions.Password))
{
Expand All @@ -182,13 +159,28 @@ public static IServiceCollection AddMongoClient(this IServiceCollection services
}

var settings = MongoClientSettings.FromUrl(new MongoUrl(connectionString));

// Configure the connection settings
settings.AllowInsecureTls = mongoOptions.AllowInsecureTls;
settings.UseTls = mongoOptions.Tls;
settings.DirectConnection = mongoOptions.DirectConnection;
settings.Scheme = ConnectionStringScheme.MongoDB;
settings.MaxConnectionPoolSize = mongoOptions.MaxConnectionPoolSize;
settings.ServerSelectionTimeout = mongoOptions.ServerSelectionTimeout;
settings.ReplicaSetName = mongoOptions.ReplicaSet;

if (!string.IsNullOrEmpty(mongoOptions.CAFile))
{
var validationCallback = CreateCallback(mongoOptions.CAFile,
logger);

settings.SslSettings = new SslSettings
{
EnabledSslProtocols = SslProtocols.Tls12,
ServerCertificateValidationCallback = validationCallback,
};
}

settings.ClusterConfigurator = cb =>
{
//cb.Subscribe<CommandStartedEvent>(e => logger.LogTrace("{CommandName} - {Command}",
Expand All @@ -197,6 +189,7 @@ public static IServiceCollection AddMongoClient(this IServiceCollection services
cb.Subscribe(new DiagnosticsActivityEventSubscriber());
};


var client = new MongoClient(settings);

services.AddSingleton<IMongoClient>(client);
Expand Down
135 changes: 135 additions & 0 deletions Utils/src/ServerCertificateValidator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2025. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System.IO;
using System.Linq;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

using Microsoft.Extensions.Logging;

namespace ArmoniK.Core.Utils;

/// <summary>
/// Provides utilities for validating SSL/TLS certificates.
/// </summary>
public static class CertificateValidator
{
/// <summary>
/// Creates a callback function to validate SSL/TLS certificates during a secure connection.
/// </summary>
/// <param name="logger">The logger to use for logging validation details.</param>
/// <param name="authority">The root certificate authority to trust during validation.</param>
/// <returns>
/// A <see cref="RemoteCertificateValidationCallback" /> delegate that performs SSL/TLS certificate validation.
/// </returns>
public static RemoteCertificateValidationCallback ValidationCallback(ILogger logger,
X509Certificate2 authority)
=> (sender,
certificate,
chain,
sslPolicyErrors) =>
{
if (certificate == null || chain == null)
{
logger.LogWarning("Certificate or certificate chain is null");
return false;
}

// If there is any error other than untrusted root or partial chain, fail the validation
if ((sslPolicyErrors & ~SslPolicyErrors.RemoteCertificateChainErrors) != 0)
{
logger.LogDebug("SSL validation failed with errors: {sslPolicyErrors}",
sslPolicyErrors);
return false;
}

if (certificate == null)
{
logger.LogDebug("Certificate is null!");
return false;
}

if (chain == null)
{
logger.LogDebug("Certificate chain is null!");
return false;
}

// If there is any error other than untrusted root or partial chain, fail the validation
if (chain.ChainStatus.Any(status => status.Status is not X509ChainStatusFlags.UntrustedRoot and not X509ChainStatusFlags.PartialChain))
{
logger.LogDebug("SSL validation failed with chain status: {chainStatus}",
chain.ChainStatus);
return false;
}

var cert = new X509Certificate2(certificate);
chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
chain.ChainPolicy.VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority;

chain.ChainPolicy.ExtraStore.Add(authority);
if (!chain.Build(cert))
{
return false;
}

var isTrusted = chain.ChainElements.Any(x => x.Certificate.Thumbprint == authority.Thumbprint);
if (isTrusted)
{
logger.LogInformation("SSL validation succeeded");
}
else
{
logger.LogInformation("SSL validation failed with errors: {sslPolicyErrors}",
sslPolicyErrors);
}

return isTrusted;
};

/// <summary>
/// Creates a certificate validation callback from a Certificate Authority (CA) file.
/// </summary>
/// <param name="caFilePath">The file path to the CA certificate.</param>
/// <param name="logger">The logger to use for logging validation details.</param>
/// <returns>
/// A <see cref="RemoteCertificateValidationCallback" /> delegate that performs SSL/TLS certificate validation.
/// </returns>
/// <exception cref="FileNotFoundException">
/// Thrown if the specified CA certificate file is not found.
/// </exception>
public static RemoteCertificateValidationCallback CreateCallback(string caFilePath,
ILogger logger)
{
if (!File.Exists(caFilePath))
{
logger.LogError("CA certificate Mongo file not found at {path}",
caFilePath);
throw new FileNotFoundException("CA certificate Mongo file not found",
caFilePath);
}

var content = File.ReadAllText(caFilePath);
var authority = X509Certificate2.CreateFromPem(content);
logger.LogInformation("Loaded CA certificate from file {path}",
caFilePath);
var callback = ValidationCallback(logger,
authority);
return callback;
}
}
4 changes: 4 additions & 0 deletions terraform/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ module "submitter" {
generated_env_vars = local.environment
log_driver = module.fluenbit.log_driver
volumes = local.volumes
mounts = module.database.core_mounts
}

module "compute_plane" {
Expand All @@ -134,6 +135,7 @@ module "compute_plane" {
volumes = local.volumes
network = docker_network.armonik.id
log_driver = module.fluenbit.log_driver
mounts = module.database.core_mounts
}

module "metrics_exporter" {
Expand All @@ -143,6 +145,7 @@ module "metrics_exporter" {
network = docker_network.armonik.id
generated_env_vars = local.environment
log_driver = module.fluenbit.log_driver
mounts = module.database.core_mounts
}

module "partition_metrics_exporter" {
Expand All @@ -153,6 +156,7 @@ module "partition_metrics_exporter" {
generated_env_vars = local.environment
metrics_env_vars = module.metrics_exporter.metrics_env_vars
log_driver = module.fluenbit.log_driver
mounts = module.database.core_mounts
}

module "ingress" {
Expand Down
4 changes: 4 additions & 0 deletions terraform/modules/compute_plane/inputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ variable "volumes" {
type = map(string)
}

variable "mounts" {
type = map(string)
}

variable "replica_counter" {
type = number
}
Expand Down
8 changes: 8 additions & 0 deletions terraform/modules/compute_plane/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,13 @@ resource "docker_container" "polling_agent" {
}
}

dynamic "upload" {
for_each = var.mounts
content {
source = upload.value
file = upload.key
}
}

depends_on = [docker_container.worker]
}
4 changes: 4 additions & 0 deletions terraform/modules/monitoring/metrics/inputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ variable "generated_env_vars" {
type = map(string)
}

variable "mounts" {
type = map(string)
}

variable "exposed_port" {
type = number
default = 5002
Expand Down
8 changes: 8 additions & 0 deletions terraform/modules/monitoring/metrics/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ resource "docker_container" "metrics" {
internal = 1080
external = var.exposed_port
}

dynamic "upload" {
for_each = var.mounts
content {
source = upload.value
file = upload.key
}
}
}
4 changes: 4 additions & 0 deletions terraform/modules/monitoring/partition_metrics/inputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ variable "generated_env_vars" {
type = map(string)
}

variable "mounts" {
type = map(string)
}

variable "metrics_env_vars" {
type = map(string)
}
Expand Down
8 changes: 8 additions & 0 deletions terraform/modules/monitoring/partition_metrics/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ resource "docker_container" "partition_metrics" {
internal = 1080
external = var.exposed_port
}

dynamic "upload" {
for_each = var.mounts
content {
source = upload.value
file = upload.key
}
}
}
Loading
Loading