Deprecate CLOUD_PROVIDER and require STORAGE_BUCKET to contain the bucket protocol (#41)

* Unify bucket env vars as `STORAGE_BUCKET`.

* Determine `ModelConfigWriter` based on bucket path.

* Remove `CLOUD_PROVIDER` environment variable.

* Deprecate `--cloud-provider` argument, but keep it for backward compatibility (see the sketch below).

* Require `--storage-bucket` as an argparse argument.

* Set default `STORAGE_BUCKET` value to `gs://deepcell-models`.

* Update the README with new 2-container instructions.

* Add `--no-cache-dir` to the `pip install` command for the config-writer Docker image.
willgraf authored Nov 25, 2020
1 parent e9f4f0b commit cb9cba9
Showing 8 changed files with 73 additions and 40 deletions.
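In practice the change collapses two provider-specific settings into one flag; a minimal sketch of the old vs. new invocation (bucket name is a placeholder, other arguments elided):

```bash
# old style (now deprecated): provider chosen explicitly
python write_config_file.py --cloud-provider=gke ...

# new style: the gs:// or s3:// prefix on --storage-bucket selects the writer
python write_config_file.py --storage-bucket=gs://deepcell-models
```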
8 changes: 3 additions & 5 deletions .env.example
@@ -9,14 +9,12 @@ MODEL_CONFIG_FILE=
 
 # Cloud Storage Bucket
 MODEL_PREFIX=
-CLOUD_PROVIDER=
+STORAGE_BUCKET=
 
 # AWS Credentials
-AWS_S3_BUCKET=
 AWS_ACCESS_KEY_ID=
 AWS_SECRET_ACCESS_KEY=
 
-# Google Cloud Credentials
-GCLOUD_STORAGE_BUCKET=
-# Optional (Can use gcloud CLI to authenticate instead)
+# Optional GKE Credentials
+# (Can use gcloud CLI to authenticate instead)
 GOOGLE_APPLICATION_CREDENTIALS=
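With the unified variable, the protocol prefix alone selects the storage backend; for example (bucket names are placeholders):

```bash
# a gs:// bucket is handled by the GCS config writer
STORAGE_BUCKET=gs://deepcell-models

# an s3:// bucket is handled by the S3 config writer (AWS keys above still required)
STORAGE_BUCKET=s3://my-model-bucket
```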
22 changes: 11 additions & 11 deletions README.md
@@ -4,30 +4,31 @@
 [![Coverage Status](https://coveralls.io/repos/github/vanvalenlab/kiosk-tf-serving/badge.svg?branch=master)](https://coveralls.io/github/vanvalenlab/kiosk-tf-serving?branch=master)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](/LICENSE)
 
-`kiosk-tf-serving` uses [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) to serve deep learning models over gRPC and REST APIs. A configuration file is automatically created on startup which allows any model found in a (AWS or GCS) storage bucket to be served.
+`kiosk-tf-serving` uses [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) to serve deep learning models over gRPC and REST APIs. A configuration file can be automatically created using `python write_config_file.py` to allow any model found in an (AWS or GCS) storage bucket to be served.
 
 TensorFlow Serving will host all versions of all models in the bucket via gRPC and REST APIs.
 
 This repository is part of the [DeepCell Kiosk](https://github.com/vanvalenlab/kiosk-console). More information about the Kiosk project is available through [Read the Docs](https://deepcell-kiosk.readthedocs.io/en/master) and our [FAQ](http://www.deepcell.org/faq) page.
 
 ## Docker
 
-Compile the docker container by running
+Build the docker image by running
 
 ```bash
-docker build --pull -t $(whoami)/kiosk-tf-serving .
+docker build --pull -t $(whoami)/kiosk-tf-serving -f docker/Dockerfile.server .
 ```
 
-Run the docker container by running
+Run the docker image by running
 
 ```bash
-NV_GPU='0' nvidia-docker run -it \
-  --runtime=nvidia \
-  -e MODEL_PREFIX=models \
+# write the configuration files for a given bucket
+python write_config_file.py --storage-bucket=$STORAGE_BUCKET
+
+# mount the config files and run the image
+docker run --gpus=1 -it \
+  -v $PWD:/config \
   -e PORT=8500 \
   -e REST_API_PORT=8501 \
-  -e CLOUD_PROVIDER=gke \
-  -e GCLOUD_STORAGE_BUCKET=YOUR_BUCKET_NAME \
  -p 8500:8500 \
  -p 8501:8501 \
  $(whoami)/kiosk-tf-serving:latest
@@ -39,8 +40,7 @@ The `kiosk-tf-serving` can be configured using environmental variables in a `.env` file.
 
 | Name | Description | Default Value |
 | :--- | :--- | :--- |
-| `GCLOUD_STORAGE_BUCKET` | **REQUIRED**: Cloud storage bucket address (e.g. `"gs://bucket-name"`). | `""` |
-| `CLOUD_PROVIDER` | **REQUIRED**: The cloud provider hosting the DeepCell Kiosk. | `"gke"` |
+| `STORAGE_BUCKET` | **REQUIRED**: Cloud storage bucket address (e.g. `"gs://bucket-name"`). | `""` |
 | `PORT` | Port to listen on for gRPC API. | `8500` |
 | `REST_API_PORT` | Port to listen on for HTTP/REST API. | `8501` |
 | `REST_API_TIMEOUT` | Timeout in ms for HTTP/REST API calls. | `30000` |
2 changes: 1 addition & 1 deletion bin/write.sh
@@ -2,7 +2,7 @@
 
 # write the configuration files before running the server
 python write_config_file.py \
-    --cloud-provider=$CLOUD_PROVIDER \
+    --storage-bucket=$STORAGE_BUCKET \
     --model-prefix=$MODEL_PREFIX \
     --file-path=$MODEL_CONFIG_FILE \
     --monitoring-enabled=$PROMETHEUS_MONITORING_ENABLED \
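For reference, the script reads everything from the environment; a sketch of a manual run (paths and values are examples only):

```bash
export STORAGE_BUCKET=gs://deepcell-models
export MODEL_PREFIX=models
export MODEL_CONFIG_FILE=/config/models.conf
export PROMETHEUS_MONITORING_ENABLED=true
./bin/write.sh
```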
4 changes: 2 additions & 2 deletions docker/Dockerfile.writer
@@ -28,7 +28,7 @@ FROM python:3.7-slim-buster
 
 WORKDIR /usr/src/app
 
-ENV CLOUD_PROVIDER=aws \
+ENV STORAGE_BUCKET=gs://deepcell-models \
     MODEL_PREFIX=models \
     PROMETHEUS_MONITORING_ENABLED=true \
     PROMETHEUS_MONITORING_PATH=/monitoring/prometheus/metrics \
@@ -41,7 +41,7 @@ ENV CLOUD_PROVIDER=aws \
 
 # Copy requirements.txt and install python dependencies
 COPY requirements.txt requirements.txt
-RUN pip install -r requirements.txt
+RUN pip install -r requirements.txt --no-cache-dir
 
 # Copy python script to generate model configuration file
 COPY writers write_config_file.py /usr/src/app/
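The writer image can now be built and run on its own; a hedged sketch of the two-container workflow (the image tag and config mount are assumptions, not part of this commit):

```bash
# build the standalone config-writer image
docker build -t $(whoami)/kiosk-tf-serving-writer -f docker/Dockerfile.writer .

# override the default bucket and write the config to a mounted directory
docker run -e STORAGE_BUCKET=gs://deepcell-models \
  -e MODEL_CONFIG_FILE=/config/models.conf \
  -v $PWD:/config \
  $(whoami)/kiosk-tf-serving-writer
```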
35 changes: 18 additions & 17 deletions write_config_file.py
@@ -74,8 +74,8 @@ def get_arg_parser():
     # Model Config args
     parser.add_argument('-c', '--cloud-provider',
                         choices=['aws', 'gke'],
-                        required=True,
-                        help='Cloud Provider')
+                        required=False,
+                        help='DEPRECATED: Cloud Provider')
 
     parser.add_argument('-p', '--model-prefix',
                         default='/',
@@ -85,6 +85,10 @@ def get_arg_parser():
                         default=os.path.join(root_dir, 'models.conf'),
                         help='Full filepath of configuration file')
 
+    parser.add_argument('-b', '--storage-bucket', required=True,
+                        help='Cloud Storage Bucket '
+                             '(e.g. gs://deepcell-models)')
+
     # Batch Config Args
     parser.add_argument('--enable-batching', type=bool, default=True,
                         help='Boolean switch for batching configuration.')
@@ -119,22 +123,19 @@ def get_arg_parser():
 
 
 def write_model_config_file(args):
-    # Create the ConfigWriter based on the cloud provider
-    if args.cloud_provider.lower() == 'aws':
-        writer = writers.S3ConfigWriter(
-            bucket=config('AWS_S3_BUCKET'),
-            model_prefix=args.model_prefix,
-            aws_access_key_id=config('AWS_ACCESS_KEY_ID'),
-            aws_secret_access_key=config('AWS_SECRET_ACCESS_KEY'))
-
-    elif args.cloud_provider.lower() == 'gke':
-        writer = writers.GCSConfigWriter(
-            bucket=config('GCLOUD_STORAGE_BUCKET'),
-            model_prefix=args.model_prefix)
-
-    else:
-        raise ValueError('Expected `cloud_provider` to be one of'
-                         ' ["aws", "gke"]. Got {}'.format(
-                             args.cloud_provider))
+    # Create the ConfigWriter based on the bucket protocol
+    writer_cls = writers.get_model_config_writer(args.storage_bucket)
+
+    writerkwargs = {
+        'bucket': str(args.storage_bucket).split('://')[-1],
+        'model_prefix': args.model_prefix,
+    }
+
+    # additional AWS credentials are required for S3
+    if issubclass(writer_cls, writers.S3ConfigWriter):
+        writerkwargs['aws_access_key_id'] = config('AWS_ACCESS_KEY_ID')
+        writerkwargs['aws_secret_access_key'] = config('AWS_SECRET_ACCESS_KEY')
+
+    writer = writer_cls(**writerkwargs)
 
     # Write the config file
     writer.write(args.file_path)
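The `bucket` kwarg above strips the protocol prefix before it reaches the writer; a quick illustration of that expression (values are examples only):

```python
# 'gs://deepcell-models' -> 'deepcell-models'
assert str('gs://deepcell-models').split('://')[-1] == 'deepcell-models'

# note that any path after the bucket name stays in the same string
assert 'gs://deepcell-models/models'.split('://')[-1] == 'deepcell-models/models'
```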
1 change: 1 addition & 0 deletions writers/__init__.py
@@ -32,6 +32,7 @@
 from writers.writers import GCSConfigWriter
 from writers.writers import MonitoringConfigWriter
 from writers.writers import BatchConfigWriter
+from writers.writers import get_model_config_writer
 
 del absolute_import
 del division
21 changes: 21 additions & 0 deletions writers/writers.py
@@ -268,3 +268,24 @@ def _get_models_from_bucket(self):
         blobs = (b.name for b in bucket.list_blobs(prefix=self.model_prefix))
         for model in self._filter_models(blobs):
             yield model
+
+
+def get_model_config_writer(bucket):
+    """Based on the bucket address, return the appropriate ConfigWriter class.
+
+    Args:
+        bucket (str): Path of the storage bucket to use.
+
+    Returns:
+        ModelConfigWriter: Class to read the bucket and create a model config.
+    """
+    b = str(bucket).lower()
+    if b.startswith('s3://'):
+        return S3ConfigWriter
+
+    if b.startswith('gs://'):
+        return GCSConfigWriter
+
+    protocol = b.split('://')[0]
+    raise ValueError('Unknown bucket protocol "{}" in bucket "{}"'.format(
+        protocol, b))
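A minimal usage sketch of the new helper (bucket and prefix are placeholders; the final `write` call assumes the bucket is actually reachable):

```python
from writers import get_model_config_writer, GCSConfigWriter

# the gs:// prefix maps to the GCS writer; note a class is returned, not an instance
writer_cls = get_model_config_writer('gs://deepcell-models')
assert writer_cls is GCSConfigWriter

# instantiate with the protocol-less bucket name, then write the model config
writer = writer_cls(bucket='deepcell-models', model_prefix='models')
writer.write('models.conf')
```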
20 changes: 16 additions & 4 deletions writers/writers_test.py
@@ -36,7 +36,19 @@
 
 import pytest
 
-from writers import writers
+import writers
+
+
+def test_get_model_config_writer():
+    # Test s3:// protocol gets correct class
+    writer_cls = writers.get_model_config_writer('s3://bucket-name/model-path')
+    assert writer_cls is writers.S3ConfigWriter
+
+    writer_cls = writers.get_model_config_writer('gs://bucket-name/model-path')
+    assert writer_cls is writers.GCSConfigWriter
+
+    with pytest.raises(ValueError):
+        writers.get_model_config_writer('abc://bucket-name/model-path')
 
 
 class TestConfigWriter(object):
@@ -160,7 +172,8 @@ class TestModelConfigWriter(object):
     def _get_writer(self):
         bucket = 'test-bucket'
         prefix = 'models'
-        return writers.ModelConfigWriter(bucket, prefix, protocol='test')
+        return writers.writers.ModelConfigWriter(
+            bucket, prefix, protocol='test')
 
     def test_get_model_url(self):
         writer = self._get_writer()
@@ -339,8 +352,7 @@ def list_blobs(self, prefix):
         # test correctness
         with open(path) as f:
             content = f.readlines()
-        import warnings
-        warnings.warn('%s' % ''.join(content))
+
         assert content[0] == 'model_config_list: {\n'
         assert len(content) == N * 8 + 2
         clean = lambda x: x.replace(' ', '').replace('\n', '')
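The new test can be run in isolation with a standard pytest filter (assuming pytest is installed):

```bash
pytest writers/writers_test.py -k test_get_model_config_writer
```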
