Skip to content

Commit

Permalink
Add support for client cert auth in gRPC
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 18, 2024
1 parent 5410a90 commit 6cf9bdb
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 25 deletions.
4 changes: 2 additions & 2 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main_from_args(args):
args.port,
args.ssl_certificate,
args.ssl_keyfile,
args.ssl_password,
args.ssl_ca_file,
args.endpoints,
)
)
Expand All @@ -47,7 +47,7 @@ def main_from_args(args):


def main():
# Running as standalone python application
"""Runs the action server as standalone application."""
arg_parser = create_argument_parser()
cmdline_args = arg_parser.parse_args()

Expand Down
32 changes: 27 additions & 5 deletions rasa_sdk/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,31 @@
from rasa_sdk.constants import DEFAULT_SERVER_PORT, DEFAULT_ENDPOINTS_PATH


def action_arg(action):
if "/" in action:
def action_arg(actions_module_path: str) -> str:
"""Validate the action module path.
Valid action module path is python module, so it should not contain a slash.
Args:
actions_module_path: Path to the actions python module.
Returns:
actions_module_path: If provided module path is valid.
Raises:
argparse.ArgumentTypeError: If the module path is invalid.
"""
if "/" in actions_module_path:
raise argparse.ArgumentTypeError(
"Invalid actions format. Actions file should be a python module "
"and passed with module notation (e.g. directory.actions)."
)
else:
return action
return actions_module_path


def add_endpoint_arguments(parser):
def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None:
"""Add all the arguments to the argument parser."""
parser.add_argument(
"-p",
"--port",
Expand Down Expand Up @@ -47,7 +61,15 @@ def add_endpoint_arguments(parser):
"--ssl-password",
default=None,
help="If your ssl-keyfile is protected by a password, you can specify it "
"using this paramer.",
"using this parameter. "
"Not supported in grpc mode.",
)
parser.add_argument(
"--ssl-ca-file",
default=None,
help="If you want to authenticate the client using a certificate, you can "
"specify the CA certificate of the client using this parameter. "
"Supported only in grpc mode.",
)
parser.add_argument(
"--auto-reload",
Expand Down
29 changes: 15 additions & 14 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import signal

import asyncio
import ssl

import grpc
import logging
Expand Down Expand Up @@ -35,7 +34,10 @@
get_tracer_and_context,
TracerProvider,
)
from rasa_sdk.utils import check_version_compatibility, number_of_sanic_workers
from rasa_sdk.utils import (
check_version_compatibility,
number_of_sanic_workers, file_as_bytes,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +86,7 @@ async def webhook(

body = ActionExecutionFailed(
action_name=e.action_name, message=e.message
).model_dump()
).model_dump_json()
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(body)
return action_webhook_pb2.WebhookResponse()
Expand Down Expand Up @@ -147,7 +149,7 @@ async def run_grpc(
port: int = DEFAULT_SERVER_PORT,
ssl_certificate: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_password: Optional[str] = None,
ssl_ca_file: Optional[str] = None,
endpoints: str = DEFAULT_ENDPOINTS_PATH,
):
"""Start a gRPC server to handle incoming action requests.
Expand All @@ -157,7 +159,7 @@ async def run_grpc(
port: Port to start the server on.
ssl_certificate: File path to the SSL certificate.
ssl_keyfile: File path to the SSL key file.
ssl_password: Password for the SSL key file.
ssl_ca_file: File path to the SSL CA certificate file.
endpoints: Path to the endpoints file.
"""
workers = number_of_sanic_workers()
Expand All @@ -170,22 +172,21 @@ async def run_grpc(
action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server(
GRPCActionServerWebhook(executor, tracer_provider), server
)

ca_cert = file_as_bytes(ssl_ca_file) if ssl_ca_file else None

if ssl_certificate and ssl_keyfile:
# Use SSL/TLS if certificate and key are provided
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(
ssl_certificate,
keyfile=ssl_keyfile,
password=ssl_password if ssl_password else None,
)
grpc.ssl_channel_credentials()
private_key = open(ssl_keyfile, "rb").read()
certificate_chain = open(ssl_certificate, "rb").read()
private_key = file_as_bytes(ssl_keyfile)
certificate_chain = file_as_bytes(ssl_certificate)
logger.info(f"Starting gRPC server with SSL support on port {port}")
server.add_secure_port(
f"[::]:{port}",
server_credentials=grpc.ssl_server_credentials(
[(private_key, certificate_chain)]
private_key_certificate_chain_pairs=[(private_key, certificate_chain)],
root_certificates = ca_cert,
require_client_auth = True if ca_cert else False,
),
)
else:
Expand Down
32 changes: 28 additions & 4 deletions rasa_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@


class Element(dict):
"""Represents an element in a list of elements in a rich message."""
__acceptable_keys = ["title", "item_url", "image_url", "subtitle", "buttons"]

def __init__(self, *args, **kwargs):
"""Initializes an element in a list of elements in a rich message."""
kwargs = {
key: value for key, value in kwargs.items() if key in self.__acceptable_keys
}
Expand All @@ -43,6 +45,7 @@ def __init__(self, *args, **kwargs):


class Button(dict):
"""Represents a button in a rich message."""
pass


Expand Down Expand Up @@ -257,12 +260,12 @@ def check_version_compatibility(rasa_version: Optional[Text]) -> None:
rasa and rasa_sdk.
Args:
rasa_version - A string containing the version of rasa that
is making the call to the action server.
rasa_version: A string containing the version of rasa that
is making the call to the action server.
Raises:
Warning - The version of rasa version unknown or not compatible with
this version of rasa_sdk.
Warning: The version of rasa version unknown or not compatible with
this version of rasa_sdk.
"""
# Check for versions of Rasa that are too old to report their version number
if rasa_version is None:
Expand Down Expand Up @@ -386,3 +389,24 @@ def read_yaml_file(filename: Union[Text, Path]) -> Dict[Text, Any]:
return read_yaml(read_file(filename, DEFAULT_ENCODING))
except (YAMLError, DuplicateKeyError) as e:
raise YamlSyntaxException(filename, e)


def file_as_bytes(file_path: Text) -> bytes:
"""Read in a file as a byte array.
Args:
file_path: Path to the file to read.
Returns:
The file content as a byte array.
Raises:
FileNotFoundException: If the file does not exist.
"""
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundException(
f"Failed to read file, " f"'{os.path.abspath(file_path)}' does not exist."
)

0 comments on commit 6cf9bdb

Please sign in to comment.