Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLFlow Triton Plugin: Add support for s3 prefix and custom endpoint URL #5686

Merged
101 changes: 99 additions & 2 deletions deploy/mlflow-triton-plugin/mlflow_triton/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,114 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import re
from collections import namedtuple
from mlflow.exceptions import MlflowException


class Config(dict):
    """Plugin configuration assembled from environment variables.

    Reads ``TRITON_URL`` and ``TRITON_MODEL_REPO`` and exposes them as dict
    entries ``'triton_url'`` and ``'triton_model_repo'``.  When the model
    repository lives on S3 (``s3://`` scheme, optionally embedding a custom
    endpoint ``s3://[http[s]://]host:port/bucket[/prefix]``) it additionally
    sets ``'s3'`` (a boto3 client), ``'s3_bucket'`` and ``'s3_prefix'``.
    """

    def parse_path(self, path):
        """Split an S3 path into its components.

        :param path: an ``s3://`` path, either plain (``s3://bucket[/prefix]``)
            or with an embedded custom endpoint
            (``s3://[http[s]://]host:port/bucket[/prefix]``).
        :return: ``S3URI`` namedtuple with fields
            ``(protocol, host_name, host_port, bucket, prefix)``; the endpoint
            fields are empty strings for plain paths.
        :raises MlflowException: if no bucket name can be extracted.
        """
        # Cleanup extra slashes
        clean_path = self.clean_path(path)

        # Get the bucket name and the object path. Return error if path is malformed
        match = self.s3_regex.fullmatch(clean_path)
        S3URI = namedtuple(
            "S3URI", ["protocol", "host_name", "host_port", "bucket", "prefix"])
        if match:
            uri = S3URI(*match.group(1, 2, 3, 4, 5))
            # The regex captures the prefix with its leading '/'; drop it.
            if uri.prefix and uri.prefix[0] == '/':
                uri = uri._replace(prefix=uri.prefix[1:])
        else:
            # No endpoint embedded: everything after "s3://" up to the first
            # slash is the bucket, the remainder (if any) is the prefix.
            bucket_start = clean_path.find("s3://") + len("s3://")
            bucket_end = clean_path.find("/", bucket_start)

            # If there isn't a slash, the address has only the bucket
            if bucket_end > bucket_start:
                bucket = clean_path[bucket_start:bucket_end]
                prefix = clean_path[bucket_end + 1:]
            else:
                bucket = clean_path[bucket_start:]
                prefix = ""
            uri = S3URI("", "", "", bucket, prefix)

        if uri.bucket == "":
            raise MlflowException("No bucket name found in path: " + path)

        return uri

    def clean_path(self, s3_path):
        """Normalize an S3 path.

        Strips leading/trailing slashes and collapses runs of internal
        slashes to a single one, while preserving the ``s3://`` and optional
        ``http://`` / ``https://`` scheme prefixes.

        :param s3_path: path to normalize.
        :return: the normalized path string.
        :raises MlflowException: if nothing but slashes remains.
        """
        # Must handle paths with s3 prefix
        start = s3_path.find("s3://")
        if start != -1:
            path = s3_path[start + len("s3://"):]
            clean_path = "s3://"
        else:
            path = s3_path
            clean_path = ""

        # Must handle paths with https:// or http:// prefix
        https_start = path.find("https://")
        if https_start != -1:
            path = path[https_start + len("https://"):]
            clean_path += "https://"
        else:
            http_start = path.find("http://")
            if http_start != -1:
                path = path[http_start + len("http://"):]
                clean_path += "http://"

        # Remove trailing slashes
        rtrim_length = len(path.rstrip('/'))
        if rtrim_length == 0:
            raise MlflowException("Invalid bucket name: '" + path + "'")

        # Remove leading slashes
        ltrim_length = len(path) - len(path.lstrip('/'))
        if ltrim_length == len(path):
            raise MlflowException("Invalid bucket name: '" + path + "'")

        # BUG FIX: the slice previously used ``rtrim_length + 1``, which kept
        # one trailing slash.  ``rtrim_length`` is already the exclusive end
        # index of the last non-slash character, so it is used directly.
        true_path = path[ltrim_length:rtrim_length]

        # Remove extra internal slashes: copy characters, emitting at most
        # one '/' per run of consecutive slashes.
        previous_slash = False
        for ch in true_path:
            if ch == '/':
                if not previous_slash:
                    clean_path += ch
                previous_slash = True
            else:
                clean_path += ch
                previous_slash = False

        return clean_path

    def __init__(self):
        super().__init__()
        # Matches s3://[http://|https://]<host>:<port>/<bucket>[/<prefix>]
        # i.e. an S3 URI that embeds a custom endpoint (host:port).
        self.s3_regex = re.compile(
            's3://(http://|https://|)([0-9a-zA-Z\\-.]+):([0-9]+)/'
            '([0-9a-z.\\-]+)(((/[0-9a-zA-Z.\\-_]+)*)?)')
        self['triton_url'] = os.environ.get('TRITON_URL')
        self['triton_model_repo'] = os.environ.get('TRITON_MODEL_REPO')

        # BUG FIX: guard against TRITON_MODEL_REPO being unset -- .startswith
        # on None would raise AttributeError.  S3 is configured only when the
        # repo explicitly uses the s3:// scheme.
        if self['triton_model_repo'] and \
                self['triton_model_repo'].startswith('s3://'):
            uri = self.parse_path(self['triton_model_repo'])
            if uri.protocol == "https://":
                scheme = "https://"
            else:
                scheme = "http://"
            # A custom endpoint is used only when both host and port were
            # parsed from the URI; otherwise boto3's default AWS endpoint
            # resolution applies (endpoint_url=None).
            endpoint_url = None
            if uri.host_name != "" and uri.host_port != "":
                endpoint_url = '{}{}:{}'.format(
                    scheme, uri.host_name, uri.host_port)
            # Imported lazily so non-S3 deployments need no boto3 install.
            import boto3
            # boto3 handles AWS credentials
            self['s3'] = boto3.client(
                's3', endpoint_url=endpoint_url)
            self['s3_bucket'] = uri.bucket
            self['s3_prefix'] = uri.prefix
            # Normalize the repo path to s3://<bucket>[/<prefix>] with the
            # endpoint stripped out.
            self['triton_model_repo'] = 's3://{}'.format(
                '/'.join(filter(None, [uri.bucket, uri.prefix])))
yeahdongcn marked this conversation as resolved.
Show resolved Hide resolved

69 changes: 38 additions & 31 deletions deploy/mlflow-triton-plugin/mlflow_triton/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self, uri):
super(TritonPlugin, self).__init__(target_uri=uri)
self.server_config = Config()
triton_url, self.triton_model_repo = self._get_triton_server_config()
self.supported_flavors = ['triton', 'onnx'] # need to add other flavors
# need to add other flavors
self.supported_flavors = ['triton', 'onnx']
# URL cleaning for constructing Triton client
ssl = False
if triton_url.startswith("http://"):
Expand Down Expand Up @@ -189,8 +190,9 @@ def list_deployments(self):
if 's3' in self.server_config:
meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
Bucket=self.server_config['s3_bucket'],
Key=d['name'] + f'/{_MLFLOW_META_FILENAME}',
)['Body'].read().decode('utf-8'))
Key='/'.join(filter(
yeahdongcn marked this conversation as resolved.
Show resolved Hide resolved
None, [self.server_config['s3_prefix'], d['name'], _MLFLOW_META_FILENAME])),
)['Body'].read().decode('utf-8'))
elif os.path.isfile(mlflow_meta_path):
meta_dict = self._get_mlflow_meta_dict(d['name'])

Expand Down Expand Up @@ -276,23 +278,25 @@ def _generate_mlflow_meta_file(self, name, flavor, model_uri):
self.server_config['s3'].put_object(
Body=json.dumps(meta_dict, indent=4).encode('utf-8'),
Bucket=self.server_config["s3_bucket"],
Key=f'{name}/{_MLFLOW_META_FILENAME}',
Key='/'.join(filter(None,
[self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME])),
)
else:
with open(os.path.join(triton_deployment_dir, _MLFLOW_META_FILENAME),
"w") as outfile:
"w") as outfile:
json.dump(meta_dict, outfile, indent=4)

print("Saved", _MLFLOW_META_FILENAME, "to", triton_deployment_dir)

def _get_mlflow_meta_dict(self, name):
mlflow_meta_path = os.path.join(self.triton_model_repo, name,
_MLFLOW_META_FILENAME)

if 's3' in self.server_config:
mlflow_meta_dict = ast.literal_eval(self.server_config['s3'].get_object(
Bucket=self.server_config['s3_bucket'],
Key=f'{name}/{_MLFLOW_META_FILENAME}',
Key='/'.join(filter(None,
[self.server_config['s3_prefix'], name, _MLFLOW_META_FILENAME]))
)['Body'].read().decode('utf-8'))
else:
with open(mlflow_meta_path, 'r') as metafile:
Expand Down Expand Up @@ -359,64 +363,67 @@ def _get_copy_paths(self, artifact_path, name, flavor):
return copy_paths

def _walk(self, path):
"""Walk a path like os.walk() if path is dir,
"""Walk a path like os.walk() if path is dir,
return file in the expected format otherwise.
:param path: dir or file path

:return: root, dirs, files
"""
if os.path.isfile(path):
return [(os.path.dirname(path), [], [os.path.basename(path)])]
elif os.path.isdir(path):
return list(os.walk(path))
else:
raise Exception(f'path: {path} is not a valid path to a file or dir.')
raise Exception(
f'path: {path} is not a valid path to a file or dir.')

def _copy_files_to_triton_repo(self, artifact_path, name, flavor):
    """Copy model artifacts into the Triton model repository.

    Uploads recursively to S3 when the server config holds an 's3' client,
    otherwise copies on the local filesystem.

    :param artifact_path: source path of the MLflow artifacts.
    :param name: deployment (model) name.
    :param flavor: model flavor; 'onnx' and 'triton' get distinct S3 key
        layouts.  NOTE(review): any other flavor would leave ``s3_path``
        unbound in the S3 branch -- presumably callers only pass supported
        flavors; verify upstream validation.
    :return: the copy-path mapping produced by ``_get_copy_paths``.
    """
    copy_paths = self._get_copy_paths(artifact_path, name, flavor)
    for key in copy_paths:
        if 's3' in self.server_config:
            # copy model dir to s3 recursively
            for root, dirs, files in self._walk(copy_paths[key]['from']):
                for filename in files:
                    local_path = os.path.join(root, filename)

                    if flavor == "onnx":
                        # Key is the destination path relative to the model
                        # repo root, with the single leading '/' removed.
                        s3_path = os.path.join(
                            copy_paths[key]['to'].replace(
                                self.server_config['triton_model_repo'], ''),
                            filename,
                        ).replace('/', '', 1)

                    elif flavor == "triton":
                        # Preserve the artifact's layout under
                        # <s3_prefix>/<name>/...; filter(None, ...) drops an
                        # empty prefix so no leading '/' is produced.
                        rel_path = os.path.relpath(
                            local_path,
                            copy_paths[key]['from'],
                        )
                        s3_path = '/'.join(
                            filter(None, [self.server_config['s3_prefix'],
                                          name, rel_path]))

                    self.server_config['s3'].upload_file(
                        local_path,
                        self.server_config['s3_bucket'],
                        s3_path,
                    )
        else:
            if os.path.isdir(copy_paths[key]['from']):
                # Replace any existing deployment directory wholesale.
                if os.path.isdir(copy_paths[key]['to']):
                    shutil.rmtree(copy_paths[key]['to'])
                shutil.copytree(
                    copy_paths[key]['from'], copy_paths[key]['to'])
            else:
                if not os.path.isdir(copy_paths[key]['to']):
                    os.makedirs(copy_paths[key]['to'])
                shutil.copy(copy_paths[key]['from'], copy_paths[key]['to'])

    if 's3' not in self.server_config:
        # Triton requires a numeric version subdirectory; ensure "1" exists
        # for local repositories.
        triton_deployment_dir = os.path.join(self.triton_model_repo, name)
        version_folder = os.path.join(triton_deployment_dir, "1")
        os.makedirs(version_folder, exist_ok=True)

    return copy_paths

def _delete_mlflow_meta(self, filepath):
if 's3' in self.server_config:
self.server_config['s3'].delete_object(
Expand All @@ -426,27 +433,27 @@ def _delete_mlflow_meta(self, filepath):
elif os.path.isfile(filepath):
os.remove(filepath)


def _delete_deployment_files(self, name):

triton_deployment_dir = os.path.join(self.triton_model_repo, name)

if 's3' in self.server_config:
objs = self.server_config['s3'].list_objects(
Bucket=self.server_config['s3_bucket'],
Prefix=name,
Bucket=self.server_config['s3_bucket'],
Prefix='/'.join(filter(None,
[self.server_config['s3_prefix'], name])),
)

for key in objs['Contents']:
key = key['Key']
try:
self.server_config['s3'].delete_object(
Bucket=self.server_config['s3_bucket'],
Key=key,
)
)
except Exception as e:
raise Exception(f'Could not delete {key}: {e}')

else:
# Check if the deployment directory exists
if not os.path.isdir(triton_deployment_dir):
Expand Down