Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLFlow Triton Plugin: Add support for s3 prefix and custom endpoint URL #5686

Merged
4 changes: 2 additions & 2 deletions deploy/mlflow-triton-plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ python setup.py install
## Quick Start

In this documentation, we will use the files in `examples` to showcase how
the plugin interacts with Triton Infernce Server. The `onnx_float32_int32_int32`
the plugin interacts with Triton Inference Server. The `onnx_float32_int32_int32`
model in `examples` is a simple model that takes two float32 inputs, INPUT0 and
INPUT1, with shape [-1, 16], and produces two int32 outputs, OUTPUT0 and
OUTPUT1, where OUTPUT0 is the element-wise summation of INPUT0 and INPUT1 and
Expand Down Expand Up @@ -85,7 +85,7 @@ The MLFlow ONNX built-in functionalities can be used to publish `onnx` flavor
models to MLFlow directly, and the MLFlow Triton plugin will prepare the model
to the format expected by Triton. You may also log
[`config.pbtxt`](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_model_configuration.md)
as additonal artifact which Triton will be used to serve the model. Otherwise,
as additional artifact which Triton will be used to serve the model. Otherwise,
the server should be run with auto-complete feature enabled
(`--strict-model-config=false`) to generate the model configuration.

Expand Down
101 changes: 99 additions & 2 deletions deploy/mlflow-triton-plugin/mlflow_triton/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import re
from collections import namedtuple
from mlflow.exceptions import MlflowException


class Config(dict):
Expand All @@ -34,7 +37,101 @@ def __init__(self):
self['triton_model_repo'] = os.environ.get('TRITON_MODEL_REPO')

if self['triton_model_repo'].startswith('s3://'):
self.s3_regex = re.compile(
's3://(http://|https://|)([0-9a-zA-Z\\-.]+):([0-9]+)/'
'([0-9a-z.\\-]+)(((/[0-9a-zA-Z.\\-_]+)*)?)')

uri = self.parse_path(self['triton_model_repo'])
if uri.protocol == "https://":
protocol = "https://"
else:
protocol = "http://"
endpoint_url = None
if uri.host_name != "" and uri.host_port != "":
endpoint_url = '{}{}:{}'.format(
protocol, uri.host_name, uri.host_port)

import boto3
self['s3'] = boto3.client('s3') # boto3 handles AWS credentials
self['s3_bucket'] = self['triton_model_repo'].replace('s3://', '').replace('/', '')
# boto3 handles AWS credentials
self['s3'] = boto3.client(
's3', endpoint_url=endpoint_url)
self['s3_bucket'] = uri.bucket
self['s3_prefix'] = uri.prefix
self['triton_model_repo'] = 's3://{}'.format(
os.path.join(uri.bucket, uri.prefix))

def parse_path(self, path):
# Cleanup extra slashes
clean_path = self.clean_path(path)

# Get the bucket name and the object path. Return error if path is malformed
match = self.s3_regex.fullmatch(clean_path)
S3URI = namedtuple(
"S3URI", ["protocol", "host_name", "host_port", "bucket", "prefix"])
if match:
uri = S3URI(*match.group(1, 2, 3, 4, 5))
if uri.prefix and uri.prefix[0] == '/':
uri = uri._replace(prefix=uri.prefix[1:])
else:
bucket_start = clean_path.find("s3://") + len("s3://")
bucket_end = clean_path.find("/", bucket_start)

# If there isn't a slash, the address has only the bucket
if bucket_end > bucket_start:
bucket = clean_path[bucket_start:bucket_end]
prefix = clean_path[bucket_end + 1:]
else:
bucket = clean_path[bucket_start:]
prefix = ""
uri = S3URI("", "", "", bucket, prefix)

if uri.bucket == "":
raise MlflowException("No bucket name found in path: " + path)

return uri

def clean_path(self, s3_path):
# Must handle paths with s3 prefix
start = s3_path.find("s3://")
path = ""
if start != -1:
path = s3_path[start + len("s3://"):]
clean_path = ("s3://")
else:
path = s3_path
clean_path = ""

# Must handle paths with https:// or http:// prefix
https_start = path.find("https://")
if https_start != -1:
path = path[https_start + len("https://"):]
clean_path += "https://"
else:
http_start = path.find("http://")
if http_start != -1:
path = path[http_start + len("http://"):]
clean_path += "http://"

# Remove trailing slashes
rtrim_length = len(path.rstrip('/'))
if rtrim_length == 0:
raise MlflowException("Invalid bucket name: '" + path + "'")

# Remove leading slashes
ltrim_length = len(path) - len(path.lstrip('/'))
if ltrim_length == len(path):
raise MlflowException("Invalid bucket name: '" + path + "'")

# Remove extra internal slashes
true_path = path[ltrim_length:rtrim_length + 1]
previous_slash = False
for i in range(len(true_path)):
if true_path[i] == '/':
if not previous_slash:
clean_path += true_path[i]
previous_slash = True
else:
clean_path += true_path[i]
previous_slash = False

return clean_path
43 changes: 25 additions & 18 deletions deploy/mlflow-triton-plugin/mlflow_triton/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self, uri):
super(TritonPlugin, self).__init__(target_uri=uri)
self.server_config = Config()
triton_url, self.triton_model_repo = self._get_triton_server_config()
self.supported_flavors = ['triton', 'onnx'] # need to add other flavors
# need to add other flavors
self.supported_flavors = ['triton', 'onnx']
# URL cleaning for constructing Triton client
ssl = False
if triton_url.startswith("http://"):
Expand Down Expand Up @@ -189,8 +190,9 @@ def list_deployments(self):
if 's3' in self.server_config:
meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
Bucket=self.server_config['s3_bucket'],
Key=d['name'] + f'/{_MLFLOW_META_FILENAME}',
)['Body'].read().decode('utf-8'))
Key=os.path.join(
self.server_config['s3_prefix'], d['name'], _MLFLOW_META_FILENAME),
)['Body'].read().decode('utf-8'))
elif os.path.isfile(mlflow_meta_path):
meta_dict = self._get_mlflow_meta_dict(d['name'])
else:
Expand Down Expand Up @@ -278,11 +280,12 @@ def _generate_mlflow_meta_file(self, name, flavor, model_uri):
self.server_config['s3'].put_object(
Body=json.dumps(meta_dict, indent=4).encode('utf-8'),
Bucket=self.server_config["s3_bucket"],
Key=f'{name}/{_MLFLOW_META_FILENAME}',
Key=os.path.join(
self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME),
)
else:
with open(os.path.join(triton_deployment_dir, _MLFLOW_META_FILENAME),
"w") as outfile:
"w") as outfile:
json.dump(meta_dict, outfile, indent=4)

print("Saved", _MLFLOW_META_FILENAME, "to", triton_deployment_dir)
Expand All @@ -294,7 +297,8 @@ def _get_mlflow_meta_dict(self, name):
if 's3' in self.server_config:
mlflow_meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
Bucket=self.server_config['s3_bucket'],
Key=f'{name}/{_MLFLOW_META_FILENAME}',
Key=os.path.join(
self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME),
)['Body'].read().decode('utf-8'))
else:
with open(mlflow_meta_path, 'r') as metafile:
Expand Down Expand Up @@ -372,7 +376,8 @@ def _walk(self, path):
elif os.path.isdir(path):
return list(os.walk(path))
else:
raise Exception(f'path: {path} is not a valid path to a file or dir.')
raise Exception(
f'path: {path} is not a valid path to a file or dir.')

def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
copy_paths = self._get_copy_paths(artifact_path, name, flavor)
Expand All @@ -385,17 +390,19 @@ def _copy_files_to_triton_repo(self, artifact_path, name, flavor):

if flavor == "onnx":
s3_path = os.path.join(
copy_paths[key]['to'].replace(
self.server_config['triton_model_repo'], ''),
filename,
).replace('/', '', 1)
self.server_config['s3_prefix'],
copy_paths[key]['to'].replace(
self.server_config['triton_model_repo'], '').strip('/'),
filename,
)

elif flavor == "triton":
rel_path = os.path.relpath(
local_path,
copy_paths[key]['from'],
)
s3_path = f'{name}/{rel_path}'
local_path,
copy_paths[key]['from'],
)
s3_path = os.path.join(
self.server_config['s3_prefix'], name, rel_path)

self.server_config['s3'].upload_file(
local_path,
Expand All @@ -406,7 +413,8 @@ def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
if os.path.isdir(copy_paths[key]['from']):
if os.path.isdir(copy_paths[key]['to']):
shutil.rmtree(copy_paths[key]['to'])
shutil.copytree(copy_paths[key]['from'], copy_paths[key]['to'])
shutil.copytree(
copy_paths[key]['from'], copy_paths[key]['to'])
else:
if not os.path.isdir(copy_paths[key]['to']):
os.makedirs(copy_paths[key]['to'])
Expand All @@ -428,15 +436,14 @@ def _delete_mlflow_meta(self, filepath):
elif os.path.isfile(filepath):
os.remove(filepath)


def _delete_deployment_files(self, name):

triton_deployment_dir = os.path.join(self.triton_model_repo, name)

if 's3' in self.server_config:
objs = self.server_config['s3'].list_objects(
Bucket=self.server_config['s3_bucket'],
Prefix=name,
Prefix=os.path.join(self.server_config['s3_prefix'], name),
)

for key in objs['Contents']:
Expand Down