add more docs
Signed-off-by: Jeffrey Kinard <[email protected]>
Polber committed Dec 20, 2024
1 parent 17a34c4 commit a8adde3
54 changes: 38 additions & 16 deletions sdks/python/apache_beam/yaml/yaml_ml.py
@@ -86,9 +86,12 @@ def underlying_handler(self):
   @staticmethod
   def default_preprocess_fn():
     raise ValueError(
-        'Handler does not implement a default preprocess '
+        'Model Handler does not implement a default preprocess '
         'method. Please define a preprocessing method using the '
-        '\'preprocess\' tag.')
+        '\'preprocess\' tag. This is required in most cases because '
+        'most models will have a different input shape, so the model '
+        'cannot generalize how the input Row should be transformed. For '
+        'an example preprocess method, see VertexAIModelHandlerJSONProvider')
 
   def _preprocess_fn_internal(self):
     return lambda row: (row, self._preprocess_fn(row))
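
The error message above points users at a preprocess method; here is a minimal sketch of what such a callable looks like and of how `_preprocess_fn_internal` pairs each row with its preprocessed form (the Row field and JSON keys are illustrative, not from this commit): ::

    import apache_beam as beam

    # Illustrative preprocess callable: map an input Beam Row to the
    # JSON-style dict the hosted model expects.
    preprocess_fn = lambda x: {"prompt": x.prompt, "max_tokens": 50}

    row = beam.Row(prompt="What is Apache Beam?")

    # _preprocess_fn_internal keeps the original row alongside the model
    # input, so the raw example is still available after inference.
    paired = (row, preprocess_fn(row))
    # paired == (Row(prompt='What is Apache Beam?'),
    #            {'prompt': 'What is Apache Beam?', 'max_tokens': 50})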
@@ -134,17 +137,34 @@ def __init__(
       project: str,
       location: str,
       preprocess: Dict[str, str],
+      postprocess: Optional[Dict[str, str]] = None,
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
-      env_vars: Optional[Dict[str, Any]] = None,
-      postprocess: Optional[Dict[str, str]] = None):
+      env_vars: Optional[Dict[str, Any]] = None):
     """
     ModelHandler for Vertex AI.
 
+    This Model Handler can be used with RunInference to load a model hosted
+    on VertexAI. Every model that is hosted on VertexAI should have three
+    distinct, required parameters: `endpoint_id`, `project` and `location`.
+    These parameters tell the Model Handler how to access the model's
+    endpoint so that input data can be sent using an API request, and
+    inferences can be received as a response.
+
+    This Model Handler also requires a `preprocess` function to be defined.
+    Preprocessing and Postprocessing are described in more detail in the
+    RunInference docs:
+    https://beam.apache.org/releases/yamldoc/current/#runinference
+
+    Every model will have a unique input, but all requests should be
+    JSON-formatted. For example, most language models such as Llama and
+    Gemma expect a JSON with the key "prompt" (among other optional keys).
+    In Python, JSON can be expressed as a dictionary.
+
     For example: ::
 
         - type: RunInference
@@ -159,10 +179,24 @@ def __init__(
             preprocess:
               callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
 
+    In the above example, which mimics a call to a Llama 3 model hosted on
+    VertexAI, the preprocess function (in this case a lambda) takes in a
+    Beam Row with a single field, "prompt", and maps it to a dict with the
+    same field. It also specifies an optional parameter, "max_tokens",
+    which tells the model the allowed token size (input plus output tokens).
+
     Args:
       endpoint_id: the numerical ID of the Vertex AI endpoint to query.
       project: the GCP project name where the endpoint is deployed.
       location: the GCP location where the endpoint is deployed.
+      preprocess: A python callable, defined either inline, or using a file,
+        that is invoked on the input row before sending to the model to be
+        loaded by this ModelHandler. This parameter is required by the
+        `VertexAIModelHandlerJSON` ModelHandler.
+      postprocess: A python callable, defined either inline, or using a file,
+        that is invoked on the PredictionResult output by the ModelHandler
+        before parsing into the output Beam Row under the field name defined
+        by the inference_tag.
       experiment: Experiment label to apply to the
         queries. See
        https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
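
To mirror the preprocess example, here is a hedged sketch of a postprocess callable operating on the handler's PredictionResult; the fields pulled from the inference payload are assumptions about one particular endpoint's response shape: ::

    from apache_beam.ml.inference.base import PredictionResult

    def postprocess(result: PredictionResult):
      # result.example holds the original input Row; result.inference
      # holds the JSON-derived payload returned by the Vertex AI endpoint.
      # Keeping only selected fields keeps the output Row schema small.
      return {
          'prompt': result.example.prompt,  # assumed input field
          'completion': result.inference,   # raw endpoint response
      }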
@@ -183,14 +217,6 @@ def __init__(
       max_batch_duration_secs: The maximum amount of time to buffer
         a batch before emitting; used in streaming contexts.
       env_vars: Environment variables.
-      preprocess: A python callable, defined either inline, or using a file,
-        that is invoked on the input row before sending to the model to be
-        loaded by this ModelHandler. This parameter is required by the
-        `VertexAIModelHandlerJSON` ModelHandler.
-      postprocess: A python callable, defined either inline, or using a file,
-        that is invoked on the PredictionResult output by the ModelHandler
-        before parsing into the output Beam Row under the field name defined
-        by the inference_tag.
     """
 
     try:
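
Putting the signature above together, a direct instantiation might look like the following sketch; the endpoint ID, project, and region are placeholders, and the preprocess dict follows the `callable` form shown in the YAML example: ::

    handler = VertexAIModelHandlerJSONProvider(
        endpoint_id='1234567890',    # numerical Vertex AI endpoint ID
        project='my-gcp-project',    # GCP project hosting the endpoint
        location='us-central1',      # GCP region of the endpoint
        preprocess={
            'callable': 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
        },
        min_batch_size=1,
        max_batch_size=8,
        max_batch_duration_secs=10,  # buffer at most 10s when streaming
    )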
@@ -222,10 +248,6 @@ def inference_output_type(self):
     return RowTypeConstraint.from_fields([('example', Any), ('inference', Any),
                                           ('model_id', Optional[str])])
 
-  @staticmethod
-  def default_postprocess_fn():
-    return lambda x: beam.Row(**x._asdict())
-
 
 @beam.ptransform.ptransform_fn
 def run_inference(
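
For reference, the removed default_postprocess_fn simply re-exposed the PredictionResult NamedTuple's fields as a Beam Row, roughly as follows (a standalone sketch, assuming model_id defaults to None in this Beam version): ::

    import apache_beam as beam
    from apache_beam.ml.inference.base import PredictionResult

    pr = PredictionResult(example={'prompt': 'hi'},
                          inference={'text': 'hello'})

    # PredictionResult is a NamedTuple, so _asdict() yields
    # {'example': ..., 'inference': ..., 'model_id': ...}, and
    # beam.Row(**...) turns those keys into Row fields.
    row = beam.Row(**pr._asdict())
    # row.example == {'prompt': 'hi'}; row.inference == {'text': 'hello'}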
