Commit 2d2a68c

Merge branch 'main' into owl-bot-copy

nayaknishant authored Nov 16, 2023
2 parents 26253e4 + 6c1f2cc
Showing 11 changed files with 377 additions and 50 deletions.
@@ -196,11 +196,14 @@ def _create(
"contentsDeltaUri": contents_delta_uri,
},
index_update_method=index_update_method_enum,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)

if encryption_spec_key_name:
encryption_spec = gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
)
gapic_index.encryption_spec = encryption_spec

if labels:
utils.validate_labels(labels)
gapic_index.labels = labels
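
The reworked block above attaches an `EncryptionSpec` only when a CMEK key name is actually supplied, instead of always constructing one with a possibly-`None` key. A minimal sketch of the same pattern; `make_index_proto` is a hypothetical helper, not part of the SDK:

```python
from typing import Optional

from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec
from google.cloud.aiplatform_v1.types import index as gca_index_types


def make_index_proto(
    display_name: str, encryption_spec_key_name: Optional[str] = None
) -> gca_index_types.Index:
    index = gca_index_types.Index(display_name=display_name)
    if encryption_spec_key_name:
        # Only attach an EncryptionSpec when a key is given; leaving the
        # field unset means Google-managed encryption is used.
        index.encryption_spec = gca_encryption_spec.EncryptionSpec(
            kms_key_name=encryption_spec_key_name
        )
    return index
```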
@@ -76,6 +76,91 @@ class Namespace:
deny_tokens: list = field(default_factory=list)


@dataclass
class NumericNamespace:
"""NumericNamespace specifies the rules for determining the datapoints that
are eligible for each matching query, overall query is an AND across namespaces.
This uses numeric comparisons.
Args:
name (str):
Required. The name of this numeric namespace.
value_int (int):
Optional. 64 bit integer value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for intended
precision.
value_float (float):
Optional. 32 bit float value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for
intended precision.
value_double (float):
            Optional. 64 bit float value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for
intended precision.
        op (str):
            Optional. Should be specified for queries only, not for
            datapoints. Specify one operator to use for comparison.
            Datapoints for which the comparison with the query's value is
            true for this operator will be allowlisted. Choose among:
"LESS" for datapoints' values < query's value;
"LESS_EQUAL" for datapoints' values <= query's value;
"EQUAL" for datapoints' values = query's value;
"GREATER_EQUAL" for datapoints' values >= query's value;
"GREATER" for datapoints' values > query's value;
"""

name: str
value_int: Optional[int] = None
value_float: Optional[float] = None
value_double: Optional[float] = None
op: Optional[str] = None

def __post_init__(self):
"""Check NumericNamespace values are of correct types and values are
not all none.
Args:
None.
Raises:
ValueError: Numeric Namespace provided values must be of correct
types and one of value_int, value_float, value_double must exist.
"""
        # Check that at least one value is provided
        if (
            self.value_int is None
            and self.value_float is None
            and self.value_double is None
        ):
            raise ValueError(
                "Must choose one among `value_int`, "
                "`value_float` and `value_double` for "
                "intended precision."
            )

# Check value type
        if self.value_int is not None and not isinstance(self.value_int, int):
            raise ValueError(
                f"value_int must be of type int, got {type(self.value_int)}."
            )
        if self.value_float is not None and not isinstance(self.value_float, float):
            raise ValueError(
                f"value_float must be of type float, got {type(self.value_float)}."
            )
        if self.value_double is not None and not isinstance(self.value_double, float):
            raise ValueError(
                f"value_double must be of type float, got {type(self.value_double)}."
            )
# Check operator validity
if (
self.op
not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
):
            raise ValueError(
                f"Invalid operator '{self.op}', must be one of the valid operators."
            )
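
A short usage sketch of the new dataclass (import path as defined in this module; adjust if your SDK version differs):

```python
# Illustrative use of NumericNamespace and its __post_init__ validation.
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
    NumericNamespace,
)

# Query-side filter: match datapoints whose "cost" value is greater than 5.
cost_filter = NumericNamespace(name="cost", value_int=5, op="GREATER")

# Both of the following raise ValueError in __post_init__:
# NumericNamespace(name="cost", op="GREATER")              # no value_* provided
# NumericNamespace(name="cost", value_int=5.0, op="LESS")  # float given for value_int
```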


class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
"""Matching Engine index endpoint resource for Vertex AI."""

@@ -1034,6 +1119,7 @@ def find_neighbors(
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
        numeric_filter: Optional[List[NumericNamespace]] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
@@ -1082,6 +1168,11 @@
Note that returning full datapoint will significantly increase the
latency and cost of the query.
            numeric_filter (Optional[list[NumericNamespace]]):
                Optional. A list of NumericNamespaces for filtering the
                matching results. For example,
                [NumericNamespace(name="cost", value_int=5, op="GREATER")]
                will match only datapoints whose cost is greater than 5.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
@@ -1110,12 +1201,22 @@
fraction_leaf_nodes_to_search_override
)
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
# Token restricts
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
datapoint.restricts.append(restrict)
            # Numeric restricts
            for numeric_namespace in numeric_filter or []:
                numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
                numeric_restrict.namespace = numeric_namespace.name
                numeric_restrict.op = numeric_namespace.op
                # value_int / value_float / value_double form a proto oneof;
                # assign only the value that was provided (assigning None to a
                # proto scalar field is not allowed).
                if numeric_namespace.value_int is not None:
                    numeric_restrict.value_int = numeric_namespace.value_int
                if numeric_namespace.value_float is not None:
                    numeric_restrict.value_float = numeric_namespace.value_float
                if numeric_namespace.value_double is not None:
                    numeric_restrict.value_double = numeric_namespace.value_double
                datapoint.numeric_restricts.append(numeric_restrict)
find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)
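
Each NumericNamespace in numeric_filter is copied field by field onto a v1beta1 NumericRestriction, mirroring the token-restrict loop above it. A standalone sketch of that conversion, using the public v1beta1 types:

```python
from google.cloud.aiplatform_v1beta1.types import index as gca_index_v1beta1

numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
numeric_restrict.namespace = "cost"
numeric_restrict.op = "GREATER"  # string names map onto the Operator enum
numeric_restrict.value_int = 5   # setting one oneof member clears the others

datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=[0.1, 0.2, 0.3])
datapoint.numeric_restricts.append(numeric_restrict)
```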

@@ -138,11 +138,15 @@ def persistent_resource_to_cluster(
Returns:
Cluster.
"""
dashboard_address = persistent_resource.resource_runtime.access_uris.get(
"RAY_DASHBOARD_URI"
)
cluster = Cluster(
cluster_resource_name=persistent_resource.name,
network=persistent_resource.network,
state=persistent_resource.state.name,
labels=persistent_resource.labels,
dashboard_address=dashboard_address,
)
if not persistent_resource.resource_runtime_spec.ray_spec:
# skip PersistentResource without RaySpec
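
With the new field, the converted Cluster carries the Ray dashboard URI that the service publishes on the PersistentResource. A hedged sketch of reading it, assuming the preview vertex_ray helpers (the resource name is a placeholder):

```python
from google.cloud.aiplatform.preview import vertex_ray

cluster = vertex_ray.get_ray_cluster(
    cluster_resource_name="projects/my-project/locations/us-central1/persistentResources/my-cluster"  # placeholder
)
# access_uris is a map field, so dashboard_address is None when the service
# has not published a RAY_DASHBOARD_URI for this resource.
if cluster.dashboard_address:
    print(f"Ray dashboard: {cluster.dashboard_address}")
```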
@@ -29,12 +29,7 @@ def create_custom_job_with_experiment_autologging_sample(
experiment: str,
experiment_run: Optional[str] = None,
) -> None:
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)

    # Ignore the next two lines of code if the experiment you are using
    # already has a backing tensorboard instance.
tb_instance = aiplatform.Tensorboard.create()
aiplatform.init(experiment=experiment, experiment_tensorboard=tb_instance)
    aiplatform.init(
        project=project,
        location=location,
        staging_bucket=staging_bucket,
        experiment=experiment,
    )

job = aiplatform.CustomJob.from_local_script(
display_name=display_name,
@@ -19,10 +19,8 @@

def test_create_custom_job_with_experiment_autologging_sample(
mock_sdk_init,
mock_create_tensorboard,
mock_get_custom_job_from_local_script,
mock_run_custom_job,
mock_tensorboard,
):
create_custom_job_with_experiment_autologging_sample.create_custom_job_with_experiment_autologging_sample(
project=constants.PROJECT,
@@ -40,13 +38,7 @@ def test_create_custom_job_with_experiment_autologging_sample(
project=constants.PROJECT,
location=constants.LOCATION,
staging_bucket=constants.STAGING_BUCKET,
)

mock_create_tensorboard.assert_called_once()

mock_sdk_init.assert_any_call(
experiment=constants.EXPERIMENT_NAME,
experiment_tensorboard=mock_tensorboard,
)

mock_get_custom_job_from_local_script.assert_called_once_with(
@@ -29,7 +29,12 @@ def create_custom_job_with_experiment_sample(
experiment: str,
experiment_run: Optional[str] = None,
) -> None:
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)
aiplatform.init(
project=project,
location=location,
staging_bucket=staging_bucket,
        experiment=experiment,
)

job = aiplatform.CustomJob.from_local_script(
display_name=display_name,
@@ -38,6 +38,7 @@ def test_create_custom_job_with_experiment_sample(
project=constants.PROJECT,
location=constants.LOCATION,
staging_bucket=constants.STAGING_BUCKET,
experiment=constants.EXPERIMENT_NAME,
)

mock_get_custom_job_from_local_script.assert_called_once_with(