diff --git a/deploy/mlflow-triton-plugin/README.md b/deploy/mlflow-triton-plugin/README.md
index c3589bba04..c011194299 100644
--- a/deploy/mlflow-triton-plugin/README.md
+++ b/deploy/mlflow-triton-plugin/README.md
@@ -57,7 +57,7 @@ python setup.py install
 ## Quick Start
 
 In this documentation, we will use the files in `examples` to showcase how
-the plugin interacts with Triton Infernce Server. The `onnx_float32_int32_int32`
+the plugin interacts with Triton Inference Server. The `onnx_float32_int32_int32`
 model in `examples` is a simple model that takes two float32 inputs, INPUT0 and
 INPUT1, with shape [-1, 16], and produces two int32 outputs, OUTPUT0 and
 OUTPUT1, where OUTPUT0 is the element-wise summation of INPUT0 and INPUT1 and
@@ -85,7 +85,7 @@ The MLFlow ONNX built-in functionalities can be used to publish `onnx` flavor
 models to MLFlow directly, and the MLFlow Triton plugin will prepare the model
 to the format expected by Triton. You may also log
 [`config.pbtxt`](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_model_configuration.md)
-as additonal artifact which Triton will be used to serve the model. Otherwise,
+as additional artifact which Triton will be used to serve the model. Otherwise,
 the server should be run with auto-complete feature enabled
 (`--strict-model-config=false`) to generate the model configuration.
 
diff --git a/deploy/mlflow-triton-plugin/mlflow_triton/config.py b/deploy/mlflow-triton-plugin/mlflow_triton/config.py
index e9c72a7ceb..d4fce37cfa 100644
--- a/deploy/mlflow-triton-plugin/mlflow_triton/config.py
+++ b/deploy/mlflow-triton-plugin/mlflow_triton/config.py
@@ -24,6 +24,9 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import os
+import re
+from collections import namedtuple
+from mlflow.exceptions import MlflowException
 
 
 class Config(dict):
@@ -34,7 +37,101 @@ def __init__(self):
         self['triton_model_repo'] = os.environ.get('TRITON_MODEL_REPO')
 
         if self['triton_model_repo'].startswith('s3://'):
+            self.s3_regex = re.compile(
+                's3://(http://|https://|)([0-9a-zA-Z\\-.]+):([0-9]+)/'
+                '([0-9a-z.\\-]+)(((/[0-9a-zA-Z.\\-_]+)*)?)')
+
+            uri = self.parse_path(self['triton_model_repo'])
+            if uri.protocol == "https://":
+                protocol = "https://"
+            else:
+                protocol = "http://"
+            endpoint_url = None
+            if uri.host_name != "" and uri.host_port != "":
+                endpoint_url = '{}{}:{}'.format(
+                    protocol, uri.host_name, uri.host_port)
+
             import boto3
-            self['s3'] = boto3.client('s3')  # boto3 handles AWS credentials
-            self['s3_bucket'] = self['triton_model_repo'].replace('s3://', '').replace('/', '')
+
+            # boto3 handles AWS credentials
+            self['s3'] = boto3.client(
+                's3', endpoint_url=endpoint_url)
+            self['s3_bucket'] = uri.bucket
+            self['s3_prefix'] = uri.prefix
+            self['triton_model_repo'] = 's3://{}'.format(
+                os.path.join(uri.bucket, uri.prefix))
+
+    def parse_path(self, path):
+        # Cleanup extra slashes
+        clean_path = self.clean_path(path)
+
+        # Get the bucket name and the object path. Return error if path is malformed
+        match = self.s3_regex.fullmatch(clean_path)
+        S3URI = namedtuple(
+            "S3URI", ["protocol", "host_name", "host_port", "bucket", "prefix"])
+        if match:
+            uri = S3URI(*match.group(1, 2, 3, 4, 5))
+            if uri.prefix and uri.prefix[0] == '/':
+                uri = uri._replace(prefix=uri.prefix[1:])
+        else:
+            bucket_start = clean_path.find("s3://") + len("s3://")
+            bucket_end = clean_path.find("/", bucket_start)
+
+            # If there isn't a slash, the address has only the bucket
+            if bucket_end > bucket_start:
+                bucket = clean_path[bucket_start:bucket_end]
+                prefix = clean_path[bucket_end + 1:]
+            else:
+                bucket = clean_path[bucket_start:]
+                prefix = ""
+            uri = S3URI("", "", "", bucket, prefix)
+
+        if uri.bucket == "":
+            raise MlflowException("No bucket name found in path: " + path)
+
+        return uri
+
+    def clean_path(self, s3_path):
+        # Must handle paths with s3 prefix
+        start = s3_path.find("s3://")
+        path = ""
+        if start != -1:
+            path = s3_path[start + len("s3://"):]
+            clean_path = ("s3://")
+        else:
+            path = s3_path
+            clean_path = ""
+
+        # Must handle paths with https:// or http:// prefix
+        https_start = path.find("https://")
+        if https_start != -1:
+            path = path[https_start + len("https://"):]
+            clean_path += "https://"
+        else:
+            http_start = path.find("http://")
+            if http_start != -1:
+                path = path[http_start + len("http://"):]
+                clean_path += "http://"
+
+        # Remove trailing slashes
+        rtrim_length = len(path.rstrip('/'))
+        if rtrim_length == 0:
+            raise MlflowException("Invalid bucket name: '" + path + "'")
+
+        # Remove leading slashes
+        ltrim_length = len(path) - len(path.lstrip('/'))
+        if ltrim_length == len(path):
+            raise MlflowException("Invalid bucket name: '" + path + "'")
+
+        # Remove extra internal slashes
+        true_path = path[ltrim_length:rtrim_length + 1]
+        previous_slash = False
+        for i in range(len(true_path)):
+            if true_path[i] == '/':
+                if not previous_slash:
+                    clean_path += true_path[i]
+                previous_slash = True
+            else:
+                clean_path += true_path[i]
+                previous_slash = False
+        return clean_path
 
diff --git a/deploy/mlflow-triton-plugin/mlflow_triton/deployments.py b/deploy/mlflow-triton-plugin/mlflow_triton/deployments.py
index e2091dcbc8..168d46399d 100644
--- a/deploy/mlflow-triton-plugin/mlflow_triton/deployments.py
+++ b/deploy/mlflow-triton-plugin/mlflow_triton/deployments.py
@@ -57,7 +57,8 @@ def __init__(self, uri):
         super(TritonPlugin, self).__init__(target_uri=uri)
         self.server_config = Config()
         triton_url, self.triton_model_repo = self._get_triton_server_config()
-        self.supported_flavors = ['triton', 'onnx']  # need to add other flavors
+        # need to add other flavors
+        self.supported_flavors = ['triton', 'onnx']
         # URL cleaning for constructing Triton client
         ssl = False
         if triton_url.startswith("http://"):
@@ -189,8 +190,9 @@ def list_deployments(self):
             if 's3' in self.server_config:
                 meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
                     Bucket=self.server_config['s3_bucket'],
-                    Key=d['name'] + f'/{_MLFLOW_META_FILENAME}',
-                )['Body'].read().decode('utf-8'))
+                    Key=os.path.join(
+                        self.server_config['s3_prefix'], d['name'], _MLFLOW_META_FILENAME),
+                )['Body'].read().decode('utf-8'))
             elif os.path.isfile(mlflow_meta_path):
                 meta_dict = self._get_mlflow_meta_dict(d['name'])
             else:
@@ -278,11 +280,12 @@ def _generate_mlflow_meta_file(self, name, flavor, model_uri):
             self.server_config['s3'].put_object(
                 Body=json.dumps(meta_dict, indent=4).encode('utf-8'),
                 Bucket=self.server_config["s3_bucket"],
-                Key=f'{name}/{_MLFLOW_META_FILENAME}',
+                Key=os.path.join(
+                    self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME),
             )
         else:
             with open(os.path.join(triton_deployment_dir, _MLFLOW_META_FILENAME),
-                      "w") as outfile:
+                    "w") as outfile:
                 json.dump(meta_dict, outfile, indent=4)
 
         print("Saved", _MLFLOW_META_FILENAME, "to", triton_deployment_dir)
@@ -294,7 +297,8 @@ def _get_mlflow_meta_dict(self, name):
         if 's3' in self.server_config:
             mlflow_meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
                 Bucket=self.server_config['s3_bucket'],
-                Key=f'{name}/{_MLFLOW_META_FILENAME}',
+                Key=os.path.join(
+                    self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME),
             )['Body'].read().decode('utf-8'))
         else:
             with open(mlflow_meta_path, 'r') as metafile:
@@ -372,7 +376,8 @@ def _walk(self, path):
         elif os.path.isdir(path):
             return list(os.walk(path))
         else:
-            raise Exception(f'path: {path} is not a valid path to a file or dir.')
+            raise Exception(
+                f'path: {path} is not a valid path to a file or dir.')
 
     def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
         copy_paths = self._get_copy_paths(artifact_path, name, flavor)
@@ -385,17 +390,19 @@ def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
 
                     if flavor == "onnx":
                         s3_path = os.path.join(
-                            copy_paths[key]['to'].replace(
-                                self.server_config['triton_model_repo'], ''),
-                            filename,
-                        ).replace('/', '', 1)
+                            self.server_config['s3_prefix'],
+                            copy_paths[key]['to'].replace(
+                                self.server_config['triton_model_repo'], '').strip('/'),
+                            filename,
+                        )
                     elif flavor == "triton":
                         rel_path = os.path.relpath(
-                            local_path,
-                            copy_paths[key]['from'],
-                        )
-                        s3_path = f'{name}/{rel_path}'
+                            local_path,
+                            copy_paths[key]['from'],
+                        )
+                        s3_path = os.path.join(
+                            self.server_config['s3_prefix'], name, rel_path)
 
                     self.server_config['s3'].upload_file(
                         local_path,
@@ -406,7 +413,8 @@ def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
             if os.path.isdir(copy_paths[key]['from']):
                 if os.path.isdir(copy_paths[key]['to']):
                     shutil.rmtree(copy_paths[key]['to'])
-                shutil.copytree(copy_paths[key]['from'], copy_paths[key]['to'])
+                shutil.copytree(
+                    copy_paths[key]['from'], copy_paths[key]['to'])
             else:
                 if not os.path.isdir(copy_paths[key]['to']):
                     os.makedirs(copy_paths[key]['to'])
@@ -428,7 +436,6 @@ def _delete_mlflow_meta(self, filepath):
         elif os.path.isfile(filepath):
             os.remove(filepath)
 
-
     def _delete_deployment_files(self, name):
 
         triton_deployment_dir = os.path.join(self.triton_model_repo, name)
@@ -436,7 +443,7 @@ def _delete_deployment_files(self, name):
 
         if 's3' in self.server_config:
             objs = self.server_config['s3'].list_objects(
                 Bucket=self.server_config['s3_bucket'],
-                Prefix=name,
+                Prefix=os.path.join(self.server_config['s3_prefix'], name),
             )
 
             for key in objs['Contents']:
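
Illustrative usage (not part of the patch): a minimal sketch of how the new Config in mlflow_triton/config.py is expected to decompose a TRITON_MODEL_REPO value that embeds a custom S3 endpoint. The endpoint, bucket, and prefix names below are hypothetical, and the snippet assumes the plugin and boto3 are installed.

import os

# Hypothetical MinIO-style endpoint, bucket, and prefix; adjust to your setup.
os.environ['TRITON_MODEL_REPO'] = (
    's3://http://minio.example.com:9000/models-bucket/team-a/models')

from mlflow_triton.config import Config

cfg = Config()
# Tracing the parse_path()/clean_path() logic added above, this should give:
#   cfg['s3_bucket']         -> 'models-bucket'
#   cfg['s3_prefix']         -> 'team-a/models'
#   cfg['triton_model_repo'] -> 's3://models-bucket/team-a/models'
# and cfg['s3'] is a boto3 client created with
# endpoint_url='http://minio.example.com:9000'.
# A plain 's3://models-bucket/team-a/models' (no host:port) also works: the
# regex does not match, the fallback split supplies bucket and prefix, and
# endpoint_url stays None so boto3 uses the default AWS endpoint.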