forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update to add the Docker build files (aws#3508)
* Add CatBoost MME BYOC example * formatted * Resolving comment # 1 and 2 * Resolving comment # 1 and 2 * Resolving comment # 4 * Resolving clean up comment * Added comments about CatBoost and usage for MME * Reformatted the jupyter file * Added the container with the relevant py files * Added formatting using Black. Also fixed the comments from the Jupyter file * Added formatting using Black. Also fixed the comments from the Jupyter file * Added formatting using Black. Also fixed the comments from the Jupyter file Co-authored-by: marckarp <[email protected]> Co-authored-by: atqy <[email protected]>
- Loading branch information
1 parent
4355063
commit 2e8c261
Showing
5 changed files
with
191 additions
and
3 deletions.
There are no files selected for viewing
47 changes: 47 additions & 0 deletions
47
advanced_functionality/multi_model_catboost/container/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
FROM ubuntu:18.04

# Set a docker label to advertise multi-model support on the container
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
# Set a docker label to enable container to use SAGEMAKER_BIND_TO_PORT environment variable if present
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true

# Install necessary dependencies for MMS and SageMaker Inference Toolkit.
# Ubuntu 18.04 provides Python 3.6 (see the python3.6 path used further down),
# so pip is bootstrapped with the get-pip.py published for 3.6; the scripts
# for newer interpreters refuse to run on it. `curl -f` makes an HTTP error
# fail the build instead of saving an error page as get-pip.py.
RUN apt-get update && \
    apt-get -y install --no-install-recommends \
    build-essential \
    ca-certificates \
    openjdk-8-jdk-headless \
    python3-dev \
    curl \
    python3 \
    vim \
    && rm -rf /var/lib/apt/lists/* \
    && curl -fsSL -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \
    && python3 get-pip.py \
    && rm -f get-pip.py

# Make `python` and `pip` resolve to the Python 3 tools
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1

# Install MMS, the SageMaker Inference Toolkit, and the packages used by the
# entrypoint script and the model handler (retrying, catboost, pandas)
RUN pip3 --no-cache-dir install multi-model-server \
    sagemaker-inference \
    retrying \
    catboost \
    pandas

# Copy entrypoint script to the image
COPY dockerd-entrypoint.py /usr/local/bin/dockerd-entrypoint.py
RUN chmod +x /usr/local/bin/dockerd-entrypoint.py
# Append a JVM option to the toolkit's multi-model MMS properties file.
# NOTE: this path is tied to the Python 3.6 interpreter installed above and
# must be updated together with the base image.
RUN echo "vmargs=-XX:-UseContainerSupport" >> /usr/local/lib/python3.6/dist-packages/sagemaker_inference/etc/mme-mms.properties

RUN mkdir -p /home/model-server/

# Copy the default custom service file to handle incoming data and inference requests
COPY model_handler.py /home/model-server/model_handler.py

# Define an entrypoint script for the docker image
ENTRYPOINT ["python", "/usr/local/bin/dockerd-entrypoint.py"]

# Define command to be passed to the entrypoint
CMD ["serve"]
Empty file.
33 changes: 33 additions & 0 deletions
33
advanced_functionality/multi_model_catboost/container/dockerd-entrypoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import subprocess | ||
import sys | ||
import shlex | ||
import os | ||
from retrying import retry | ||
from subprocess import CalledProcessError | ||
from sagemaker_inference import model_server | ||
|
||
|
||
def _retry_if_error(exception):
    """Return True when *exception* should trigger another start attempt.

    Used as the ``retry_on_exception`` predicate of the ``@retry`` decorator.
    Both ``CalledProcessError`` and ``OSError`` (including its subclasses)
    are treated as retryable; any other exception is not.

    :param exception: the exception raised by the decorated call
    :return: bool
    """
    # isinstance needs a tuple to test several types. The expression
    # `CalledProcessError or OSError` evaluates to CalledProcessError alone,
    # which silently excluded OSError from the retry policy.
    return isinstance(exception, (CalledProcessError, OSError))
|
||
|
||
@retry(stop_max_delay=50 * 1000, retry_on_exception=_retry_if_error)
def _start_mms():
    """Configure the worker environment and start the model server.

    The call is retried (for exceptions accepted by ``_retry_if_error``)
    until the ``stop_max_delay`` budget of the ``retry`` decorator runs out.
    """
    # By default the number of workers per model is 1; the values below
    # override that and set the OpenMP thread count before the server starts.
    os.environ.update(
        {
            "MMS_DEFAULT_WORKERS_PER_MODEL": "2",
            "OMP_NUM_THREADS": "8",
        }
    )
    model_server.start_model_server(
        handler_service="/home/model-server/model_handler.py:handle"
    )
|
||
|
||
def main():
    """Dispatch on the first command-line argument, then keep the process alive.

    ``serve`` starts the model server; anything else is re-joined into a
    single command line, split with ``shlex`` and executed, raising
    ``CalledProcessError`` on a non-zero exit status.
    """
    command = sys.argv[1]
    if command != "serve":
        subprocess.check_call(shlex.split(" ".join(sys.argv[1:])))
    else:
        _start_mms()

    # prevent docker exit
    subprocess.call(["tail", "-f", "/dev/null"])
|
||
|
||
# Run the entrypoint only when this file is executed as a script, so that
# importing the module does not start the server or block on `tail`.
if __name__ == "__main__":
    main()
108 changes: 108 additions & 0 deletions
108
advanced_functionality/multi_model_catboost/container/model_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
import json | ||
import sys | ||
import logging | ||
import time | ||
import catboost | ||
from catboost import CatBoostClassifier | ||
import pandas as pd | ||
import io | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
import os | ||
|
||
|
||
class ModelHandler(object):
    """Request handler that serves one CatBoost classifier.

    ``initialize`` loads the first ``.bin`` file found in the model directory
    named by the server context; ``handle`` parses a CSV request body, runs
    ``predict_proba`` and returns the flattened probabilities.
    """

    def __init__(self):
        start = time.time()
        self.initialized = False
        self.device = "cpu"
        # Populated by initialize(); None means no model has been loaded yet.
        self.model = None
        print(f" perf __init__ {(time.time() - start) * 1000} ms")

    @staticmethod
    def _model_file_paths(model_dir):
        """Return the sorted full paths of the regular ``.bin`` files in *model_dir*."""
        return sorted(
            os.path.join(model_dir, name)
            for name in os.listdir(model_dir)
            if name.endswith(".bin") and os.path.isfile(os.path.join(model_dir, name))
        )

    def initialize(self, ctx):
        """
        Load the model from the directory given by the server context
        :param ctx: context object exposing ``system_properties``; its
            ``model_dir`` entry is the directory to search
        :raises FileNotFoundError: if the directory holds no ``.bin`` file
        """
        start = time.time()
        self.device = "cpu"

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        print("model_dir {}".format(model_dir))
        # List the directory from Python rather than shelling out to `ls`
        # with an interpolated path.
        print(sorted(os.listdir(model_dir)))

        model_paths = self._model_file_paths(model_dir)
        print(
            f"Modelhandler:model_file location::{model_dir}:: files:bin:={model_paths} :: going to load the first one::"
        )
        if not model_paths:
            raise FileNotFoundError(f"No .bin model file found in {model_dir}")

        # Load by full path: a bare file name would only resolve when the
        # current working directory happens to be model_dir.
        classifier = CatBoostClassifier()
        self.model = classifier.load_model(model_paths[0])

        self.initialized = True
        print(f" perf initialize {(time.time() - start) * 1000} ms")

    def preprocess(self, input_data):
        """
        Pre-process the request (currently a pass-through)
        :param input_data: parsed request data
        :return: the same object, unchanged
        """
        start = time.time()
        print(type(input_data))
        output = input_data
        print(f" perf preprocess {(time.time() - start) * 1000} ms")
        return output

    def inference(self, inputs):
        """
        Make the inference request against the loaded model
        :param inputs: data passed to the model's ``predict_proba``
        :return: the value returned by ``predict_proba``
        """
        start = time.time()

        predictions = self.model.predict_proba(inputs)
        print(f" perf inference {(time.time() - start) * 1000} ms")
        return predictions

    def postprocess(self, inference_output):
        """
        Post-process the request
        :param inference_output: array-like object providing ``flatten()``
        :return: one-element list holding a dict that maps each position in
            the flattened output to its value
        """
        start = time.time()
        inference_output = dict(enumerate(inference_output.flatten(), 0))
        print(f" perf postprocess {(time.time() - start) * 1000} ms")
        return [inference_output]

    def handle(self, data, context):
        """
        Call pre-process, inference and post-process functions
        :param data: input data; ``data[0]["body"]`` must be bytes-like CSV
        :param context: mms context
        :return: the post-processed prediction
        """
        start = time.time()

        input_data = data[0]["body"].decode()
        # read_csv's default treats the first CSV line as the header row.
        df = pd.read_csv(io.StringIO(input_data))

        model_input = self.preprocess(df)
        model_output = self.inference(model_input)
        print(f" perf handle in {(time.time() - start) * 1000} ms")
        return self.postprocess(model_output)
|
||
|
||
# Module-level handler instance reused by every call to handle().
_service = ModelHandler()


def handle(data, context):
    """Module-level entry point that delegates to the shared ``ModelHandler``.

    The handler is initialized from *context* on first use. A ``None``
    payload returns ``None`` without running inference.

    :param data: request payload, or None
    :param context: mms context
    :return: the result of ``ModelHandler.handle``, or None
    """
    began = time.time()
    if not _service.initialized:
        _service.initialize(context)

    if data is not None:
        print(f" perf handle_out {(time.time() - began) * 1000} ms")
        return _service.handle(data, context)

    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters