Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support complex metrics in Vertex Experiments #1698

Merged
merged 21 commits into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
3ca2d28
Experiments complex metrics (#8)
jaycee-li Sep 26, 2022
ee36af9
fix: import error
jaycee-li Sep 26, 2022
c1aa713
Update log_classification_metrics_sample.py
jaycee-li Sep 26, 2022
0ea90bf
Update samples/model-builder/experiment_tracking/log_classification_m…
jaycee-li Sep 26, 2022
33b03ff
Update log_classification_metrics_sample_test.py
jaycee-li Sep 26, 2022
775c8cf
Merge branch 'complex-metrics' of https://github.com/jaycee-li/python…
jaycee-li Sep 26, 2022
5bf0d13
Update samples/model-builder/conftest.py
jaycee-li Sep 26, 2022
df5c5d1
fix: unit test
jaycee-li Sep 26, 2022
4b49fd8
Merge branch 'main' into complex-metrics
jaycee-li Sep 26, 2022
a82194a
fix comments
jaycee-li Sep 28, 2022
9c04a50
fix comments and update google.ClassificationMetrics
jaycee-li Sep 29, 2022
52af4d3
Merge branch 'main' into complex-metrics
jaycee-li Sep 29, 2022
796f196
fix comments and update ClassificationMetrics class
jaycee-li Sep 29, 2022
59648e3
Merge branch 'complex-metrics' of https://github.com/jaycee-li/python…
jaycee-li Sep 29, 2022
fbc98ab
fix: ClassificationMetrics doesn't catch params with value=0
jaycee-li Sep 29, 2022
6fb76ba
add sample for get_classification_metrics
jaycee-li Sep 29, 2022
8b223ee
Merge branch 'main' into complex-metrics
jaycee-li Sep 29, 2022
96ef2f3
fix linting
jaycee-li Sep 29, 2022
aefc4ad
Merge branch 'complex-metrics' of https://github.com/jaycee-li/python…
jaycee-li Sep 29, 2022
98cd805
add todos
jaycee-li Sep 29, 2022
734893c
Merge branch 'main' into complex-metrics
jaycee-li Sep 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@

log_params = metadata.metadata._experiment_tracker.log_params
log_metrics = metadata.metadata._experiment_tracker.log_metrics
log_classification_metrics = (
metadata.metadata._experiment_tracker.log_classification_metrics
)
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
start_execution = metadata.metadata._experiment_tracker.start_execution
Expand All @@ -110,6 +113,7 @@
"log",
"log_params",
"log_metrics",
"log_classification_metrics",
"log_time_series_metrics",
"get_experiment_df",
"get_pipeline_df",
Expand Down
165 changes: 165 additions & 0 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.metadata.schema import utils as schema_utils
from google.cloud.aiplatform.metadata.schema.google import (
artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.tensorboard import tensorboard_resource
from google.cloud.aiplatform.utils import rest_utils

Expand Down Expand Up @@ -990,6 +994,108 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
# TODO: query the latest metrics artifact resource before logging.
self._metadata_node.update(metadata={constants._METRIC_KEY: metrics})

@_v1_not_supported
def log_classification_metrics(
    self,
    *,
    labels: Optional[List[str]] = None,
    matrix: Optional[List[List[int]]] = None,
    fpr: Optional[List[float]] = None,
    tpr: Optional[List[float]] = None,
    threshold: Optional[List[float]] = None,
    display_name: Optional[str] = None,
):
    """Create an artifact for classification metrics and log to ExperimentRun.

    Currently supports confusion matrix and ROC curve. The two groups are
    independent: either one, or both, may be logged in a single call.
    Empty lists are treated the same as unset arguments.

    ```
    my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
    my_run.log_classification_metrics(
        display_name='my-classification-metrics',
        labels=['cat', 'dog'],
        matrix=[[9, 1], [1, 9]],
        fpr=[0.1, 0.5, 0.9],
        tpr=[0.1, 0.7, 0.9],
        threshold=[0.9, 0.5, 0.1],
    )
    ```

    Args:
        labels (List[str]):
            Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
        matrix (List[List[int]]):
            Optional. Values for the confusion matrix. Must be set if 'labels' is set.
        fpr (List[float]):
            Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
        tpr (List[float]):
            Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
        threshold (List[float]):
            Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
        display_name (str):
            Optional. The user-defined name for the classification metric artifact.

    Raises:
        ValueError: if 'labels' and 'matrix' are not set together
                    or if 'labels' and 'matrix' are not in the same length
                    or if 'fpr' and 'tpr' and 'threshold' are not set together
                    or if 'fpr' and 'tpr' and 'threshold' are not in the same length
    """
    if (labels or matrix) and not (labels and matrix):
        raise ValueError("labels and matrix must be set together.")

    if (fpr or tpr or threshold) and not (fpr and tpr and threshold):
        raise ValueError("fpr, tpr, and thresholds must be set together.")

    # Both metric groups are optional. Default them to None so that logging
    # only one group does not reference an unassigned local further down.
    confusion_matrix = None
    confidence_metrics = None

    if labels and matrix:
        if len(matrix) != len(labels):
            raise ValueError(
                "Length of labels and matrix must be the same. "
                "Got lengths {} and {} respectively.".format(
                    len(labels), len(matrix)
                )
            )
        annotation_specs = [
            schema_utils.AnnotationSpec(display_name=label) for label in labels
        ]
        confusion_matrix = schema_utils.ConfusionMatrix(
            annotation_specs=annotation_specs,
            matrix=matrix,
        )

    if fpr and tpr and threshold:
        if not (len(fpr) == len(tpr) == len(threshold)):
            raise ValueError(
                "Length of fpr, tpr and threshold must be the same. "
                "Got lengths {}, {} and {} respectively.".format(
                    len(fpr), len(tpr), len(threshold)
                )
            )

        # One ConfidenceMetric per ROC point; tpr is stored in the 'recall' field.
        confidence_metrics = [
            schema_utils.ConfidenceMetric(
                confidence_threshold=confidence_threshold,
                false_positive_rate=false_positive_rate,
                recall=recall,
            )
            for confidence_threshold, false_positive_rate, recall in zip(
                threshold, fpr, tpr
            )
        ]

    classification_metrics = google_artifact_schema.ClassificationMetrics(
        display_name=display_name,
        confusion_matrix=confusion_matrix,
        confidence_metrics=confidence_metrics,
    )

    # Create the artifact resource, then attach it to this run's context.
    classification_metrics_artifact = classification_metrics.create()
    self._metadata_node.add_artifacts_and_executions(
        artifact_resource_names=[classification_metrics_artifact.resource_name]
    )

@_v1_not_supported
def get_time_series_data_frame(self) -> "pd.DataFrame": # noqa: F821
"""Returns all time series in this Run as a DataFrame.
Expand Down Expand Up @@ -1149,6 +1255,65 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
else:
return self._metadata_node.metadata[constants._METRIC_KEY]

@_v1_not_supported
def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
    """Get all the classification metrics logged to this run.

    ```
    my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
    metric = my_run.get_classification_metrics()[0]
    print(metric)
    ## print result:
        {
            "id": "e6c893a4-222e-4c60-a028-6a3b95dfc109",
            "display_name": "my-classification-metrics",
            "labels": ["cat", "dog"],
            "matrix": [[9,1], [1,9]],
            "fpr": [0.1, 0.5, 0.9],
            "tpr": [0.1, 0.7, 0.9],
            "threshold": [0.9, 0.5, 0.1]
        }
    ```

    Returns:
        List of classification metrics logged to this experiment run. Each
        entry always has "id" and "display_name". "labels"/"matrix" are
        present only when the artifact stores a confusion matrix, and
        "fpr"/"tpr"/"threshold" only when it stores confidence metrics.
    """

    # Only artifacts in this run's context with the ClassificationMetrics
    # schema title are returned.
    artifact_list = artifact.Artifact.list(
        filter=metadata_utils._make_filter_string(
            in_context=[self.resource_name],
            schema_title=google_artifact_schema.ClassificationMetrics.schema_title,
        ),
        project=self.project,
        location=self.location,
        credentials=self.credentials,
    )

    metrics = []
    for metric_artifact in artifact_list:
        metric = {
            "id": metric_artifact.name,
            "display_name": metric_artifact.display_name,
        }
        # Named artifact_metadata so it does not shadow the imported
        # `metadata` module.
        artifact_metadata = metric_artifact.metadata

        if "confusionMatrix" in artifact_metadata:
            confusion_matrix = artifact_metadata["confusionMatrix"]
            metric["labels"] = [
                spec["displayName"] for spec in confusion_matrix["annotationSpecs"]
            ]
            metric["matrix"] = confusion_matrix["rows"]

        if "confidenceMetrics" in artifact_metadata:
            confidence_metrics = artifact_metadata["confidenceMetrics"]
            metric["fpr"] = [d["falsePositiveRate"] for d in confidence_metrics]
            # The true positive rate is stored under the 'recall' key.
            metric["tpr"] = [d["recall"] for d in confidence_metrics]
            metric["threshold"] = [
                d["confidenceThreshold"] for d in confidence_metrics
            ]
        metrics.append(metric)

    return metrics

@_v1_not_supported
def associate_execution(self, execution: execution.Execution):
"""Associate an execution to this experiment run.
Expand Down
59 changes: 57 additions & 2 deletions google/cloud/aiplatform/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# limitations under the License.
#


from typing import Dict, Union, Optional, Any
from typing import Dict, Union, Optional, Any, List

from google.api_core import exceptions
from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -371,6 +370,62 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
# query the latest metrics artifact resource before logging.
self._experiment_run.log_metrics(metrics=metrics)

def log_classification_metrics(
    self,
    *,
    labels: Optional[List[str]] = None,
    matrix: Optional[List[List[int]]] = None,
    fpr: Optional[List[float]] = None,
    tpr: Optional[List[float]] = None,
    threshold: Optional[List[float]] = None,
    display_name: Optional[str] = None,
):
    """Create an artifact for classification metrics and log it to the current ExperimentRun.

    Currently supports confusion matrix and ROC curve. An experiment and a
    run must already be set; the call is validated and then forwarded
    unchanged to the run's own `log_classification_metrics`.

    ```
    aiplatform.log_classification_metrics(
        display_name='my-classification-metrics',
        labels=['cat', 'dog'],
        matrix=[[9, 1], [1, 9]],
        fpr=[0.1, 0.5, 0.9],
        tpr=[0.1, 0.7, 0.9],
        threshold=[0.9, 0.5, 0.1],
    )
    ```

    Args:
        labels (List[str]):
            Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
        matrix (List[List[int]]):
            Optional. Values for the confusion matrix. Must be set if 'labels' is set.
        fpr (List[float]):
            Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
        tpr (List[float]):
            Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
        threshold (List[float]):
            Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
        display_name (str):
            Optional. The user-defined name for the classification metric artifact.

    Raises:
        ValueError: if 'labels' and 'matrix' are not set together
                    or if 'labels' and 'matrix' are not in the same length
                    or if 'fpr' and 'tpr' and 'threshold' are not set together
                    or if 'fpr' and 'tpr' and 'threshold' are not in the same length
    """
    self._validate_experiment_and_run(method_name="log_classification_metrics")

    # Argument validation and artifact creation are handled by the run.
    self._experiment_run.log_classification_metrics(
        labels=labels,
        matrix=matrix,
        fpr=fpr,
        tpr=tpr,
        threshold=threshold,
        display_name=display_name,
    )

def _validate_experiment_and_run(self, method_name: str):
"""Validates Experiment and Run are set and raises informative error message.

Expand Down
64 changes: 60 additions & 4 deletions google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import copy
from typing import Optional, Dict
from typing import Optional, Dict, List

from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata.schema import base_artifact
Expand All @@ -24,6 +24,12 @@
# The artifact property key for the resource_name
_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME = "resourceName"

_CLASSIFICATION_METRICS_AGGREGATION_TYPE = [
"AGGREGATION_TYPE_UNSPECIFIED",
"MACRO_AVERAGE",
"MICRO_AVERAGE",
]


class VertexDataset(base_artifact.BaseArtifactSchema):
"""An artifact representing a Vertex Dataset."""
Expand Down Expand Up @@ -278,9 +284,17 @@ class ClassificationMetrics(base_artifact.BaseArtifactSchema):
def __init__(
self,
*,
aggregation_type: Optional[str] = None,
aggregation_threshold: Optional[float] = None,
recall: Optional[float] = None,
precision: Optional[float] = None,
f1_score: Optional[float] = None,
accuracy: Optional[float] = None,
au_prc: Optional[float] = None,
au_roc: Optional[float] = None,
log_loss: Optional[float] = None,
confusion_matrix: Optional[utils.ConfusionMatrix] = None,
confidence_metrics: Optional[List[utils.ConfidenceMetric]] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
Expand All @@ -290,6 +304,22 @@ def __init__(
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
aggregation_type (str):
Optional. The way to generate the aggregated metrics. Choose from the following options:
"AGGREGATION_TYPE_UNSPECIFIED": Indicating unset, used for per-class sliced metrics
"MACRO_AVERAGE": The unweighted average, default behavior
"MICRO_AVERAGE": The weighted average
aggregation_threshold (float):
Optional. The threshold used to generate aggregated metrics, default 0 for multi-class classification, 0.5 for binary classification.
recall (float):
Optional. Recall (True Positive Rate) for the given confidence threshold.
precision (float):
Optional. Precision for the given confidence threshold.
f1_score (float):
Optional. The harmonic mean of recall and precision.
accuracy (float):
Optional. Accuracy is the fraction of predictions given the correct label.
For multiclass this is a micro-average metric.
au_prc (float):
Optional. The Area Under Precision-Recall Curve metric.
Micro-averaged for the overall evaluation.
Expand All @@ -298,6 +328,10 @@ def __init__(
Micro-averaged for the overall evaluation.
log_loss (float):
Optional. The Log Loss metric.
confusion_matrix (utils.ConfusionMatrix):
Optional. Aggregated confusion matrix.
confidence_metrics (List[utils.ConfidenceMetric]):
Optional. List of metrics for different confidence thresholds.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
Expand All @@ -323,12 +357,34 @@ def __init__(
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if au_prc:
if aggregation_type:
if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
jaycee-li marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
)
extended_metadata["aggregationType"] = aggregation_type
if aggregation_threshold is not None:
extended_metadata["aggregationThreshold"] = aggregation_threshold
if recall is not None:
extended_metadata["recall"] = recall
if precision is not None:
extended_metadata["precision"] = precision
if f1_score is not None:
extended_metadata["f1Score"] = f1_score
if accuracy is not None:
extended_metadata["accuracy"] = accuracy
if au_prc is not None:
extended_metadata["auPrc"] = au_prc
if au_roc:
if au_roc is not None:
extended_metadata["auRoc"] = au_roc
if log_loss:
if log_loss is not None:
extended_metadata["logLoss"] = log_loss
if confusion_matrix:
extended_metadata["confusionMatrix"] = confusion_matrix.to_dict()
if confidence_metrics:
extended_metadata["confidenceMetrics"] = [
confidence_metric.to_dict() for confidence_metric in confidence_metrics
]

super(ClassificationMetrics, self).__init__(
uri=uri,
Expand Down
Loading