feat: Support complex metrics in Vertex Experiments (#1698)
* Experiments complex metrics (#8)

* feat: new class and API for metrics

* update system test

* update high level log method

* fix system test

* update example

* change from system schema to google schema

* fix: import error

* Update log_classification_metrics_sample.py

* Update samples/model-builder/experiment_tracking/log_classification_metrics_sample.py

Co-authored-by: Dan Lee <[email protected]>

* Update log_classification_metrics_sample_test.py

* Update samples/model-builder/conftest.py

Co-authored-by: Dan Lee <[email protected]>

* fix: unit test

* fix comments

* fix comments and update google.ClassificationMetrics

* fix comments and update ClassificationMetrics class

* fix: ClassificationMetrics doesn't catch params with value=0

* add sample for get_classification_metrics

* fix linting

* add todos

Co-authored-by: Dan Lee <[email protected]>
jaycee-li and dandhlee authored Sep 30, 2022
1 parent 5fe515c commit ed0492e
Showing 14 changed files with 810 additions and 15 deletions.
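Taken together, the change wires a new `log_classification_metrics` alias into the package namespace, adds `ExperimentRun.log_classification_metrics` and `get_classification_metrics`, and extends the `google.ClassificationMetrics` artifact schema. A minimal end-to-end sketch of the new surface; the project, region, experiment, and run names below are placeholders, not values from this commit:

```
from google.cloud import aiplatform

# Placeholder project/region/experiment/run names; substitute your own.
aiplatform.init(
    project="my-project",
    location="us-central1",
    experiment="my-experiment",
)
aiplatform.start_run("my-run")

# New in this commit: log a confusion matrix and an ROC curve in one call.
aiplatform.log_classification_metrics(
    display_name="my-classification-metrics",
    labels=["cat", "dog"],
    matrix=[[9, 1], [1, 9]],
    fpr=[0.1, 0.5, 0.9],
    tpr=[0.1, 0.7, 0.9],
    threshold=[0.9, 0.5, 0.1],
)
aiplatform.end_run()

# Also new: read the logged metrics back from the run.
run = aiplatform.ExperimentRun("my-run", experiment="my-experiment")
print(run.get_classification_metrics())
```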
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/__init__.py
@@ -86,6 +86,9 @@

log_params = metadata.metadata._experiment_tracker.log_params
log_metrics = metadata.metadata._experiment_tracker.log_metrics
log_classification_metrics = (
metadata.metadata._experiment_tracker.log_classification_metrics
)
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
start_execution = metadata.metadata._experiment_tracker.start_execution
@@ -110,6 +113,7 @@
"log",
"log_params",
"log_metrics",
"log_classification_metrics",
"log_time_series_metrics",
"get_experiment_df",
"get_pipeline_df",
165 changes: 165 additions & 0 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
@@ -39,6 +39,10 @@
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.metadata.schema import utils as schema_utils
from google.cloud.aiplatform.metadata.schema.google import (
artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.tensorboard import tensorboard_resource
from google.cloud.aiplatform.utils import rest_utils

@@ -990,6 +994,108 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
# TODO: query the latest metrics artifact resource before logging.
self._metadata_node.update(metadata={constants._METRIC_KEY: metrics})

@_v1_not_supported
def log_classification_metrics(
self,
*,
labels: Optional[List[str]] = None,
matrix: Optional[List[List[int]]] = None,
fpr: Optional[List[float]] = None,
tpr: Optional[List[float]] = None,
threshold: Optional[List[float]] = None,
display_name: Optional[str] = None,
):
"""Create an artifact for classification metrics and log to ExperimentRun. Currently supports confusion matrix and ROC curve.
```
my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
my_run.log_classification_metrics(
display_name='my-classification-metrics',
labels=['cat', 'dog'],
matrix=[[9, 1], [1, 9]],
fpr=[0.1, 0.5, 0.9],
tpr=[0.1, 0.7, 0.9],
threshold=[0.9, 0.5, 0.1],
)
```
Args:
labels (List[str]):
Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
matrix (List[List[int]]):
Optional. Values for the confusion matrix. Must be set if 'labels' is set.
fpr (List[float]):
Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
tpr (List[float]):
Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
threshold (List[float]):
Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
display_name (str):
Optional. The user-defined name for the classification metric artifact.
Raises:
ValueError: if 'labels' and 'matrix' are not set together,
if 'labels' and 'matrix' do not have the same length,
if 'fpr', 'tpr' and 'threshold' are not set together,
or if 'fpr', 'tpr' and 'threshold' do not have the same length.
"""
if (labels or matrix) and not (labels and matrix):
raise ValueError("labels and matrix must be set together.")

if (fpr or tpr or threshold) and not (fpr and tpr and threshold):
raise ValueError("fpr, tpr, and thresholds must be set together.")

if labels and matrix:
if len(matrix) != len(labels):
raise ValueError(
"Length of labels and matrix must be the same. "
"Got lengths {} and {} respectively.".format(
len(labels), len(matrix)
)
)
annotation_specs = [
schema_utils.AnnotationSpec(display_name=label) for label in labels
]
confusion_matrix = schema_utils.ConfusionMatrix(
annotation_specs=annotation_specs,
matrix=matrix,
)

if fpr and tpr and threshold:
if (
len(fpr) != len(tpr)
or len(fpr) != len(threshold)
or len(tpr) != len(threshold)
):
raise ValueError(
"Length of fpr, tpr and threshold must be the same. "
"Got lengths {}, {} and {} respectively.".format(
len(fpr), len(tpr), len(threshold)
)
)

confidence_metrics = [
schema_utils.ConfidenceMetric(
confidence_threshold=confidence_threshold,
false_positive_rate=false_positive_rate,
recall=recall,
)
for confidence_threshold, false_positive_rate, recall in zip(
threshold, fpr, tpr
)
]

classification_metrics = google_artifact_schema.ClassificationMetrics(
display_name=display_name,
confusion_matrix=confusion_matrix,
confidence_metrics=confidence_metrics,
)

classification_metrics_artifact = classification_metrics.create()
self._metadata_node.add_artifacts_and_executions(
artifact_resource_names=[classification_metrics_artifact.resource_name]
)

@_v1_not_supported
def get_time_series_data_frame(self) -> "pd.DataFrame": # noqa: F821
"""Returns all time series in this Run as a DataFrame.
@@ -1149,6 +1255,65 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
else:
return self._metadata_node.metadata[constants._METRIC_KEY]

@_v1_not_supported
def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
"""Get all the classification metrics logged to this run.
```
my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
metric = my_run.get_classification_metrics()[0]
print(metric)
## print result:
{
"id": "e6c893a4-222e-4c60-a028-6a3b95dfc109",
"display_name": "my-classification-metrics",
"labels": ["cat", "dog"],
"matrix": [[9,1], [1,9]],
"fpr": [0.1, 0.5, 0.9],
"tpr": [0.1, 0.7, 0.9],
"thresholds": [0.9, 0.5, 0.1]
}
```
Returns:
List of classification metrics logged to this experiment run.
"""

artifact_list = artifact.Artifact.list(
filter=metadata_utils._make_filter_string(
in_context=[self.resource_name],
schema_title=google_artifact_schema.ClassificationMetrics.schema_title,
),
project=self.project,
location=self.location,
credentials=self.credentials,
)

metrics = []
for metric_artifact in artifact_list:
metric = {}
metric["id"] = metric_artifact.name
metric["display_name"] = metric_artifact.display_name
metadata = metric_artifact.metadata
if "confusionMatrix" in metadata:
metric["labels"] = [
d["displayName"]
for d in metadata["confusionMatrix"]["annotationSpecs"]
]
metric["matrix"] = metadata["confusionMatrix"]["rows"]

if "confidenceMetrics" in metadata:
metric["fpr"] = [
d["falsePositiveRate"] for d in metadata["confidenceMetrics"]
]
metric["tpr"] = [d["recall"] for d in metadata["confidenceMetrics"]]
metric["threshold"] = [
d["confidenceThreshold"] for d in metadata["confidenceMetrics"]
]
metrics.append(metric)

return metrics

@_v1_not_supported
def associate_execution(self, execution: execution.Execution):
"""Associate an execution to this experiment run.
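The `labels`/`matrix` and `fpr`/`tpr`/`threshold` groups above must each be supplied together with matching lengths, and `get_classification_metrics` reads the resulting artifact back through its camelCase metadata keys. Below is a rough sketch of the metadata shape one logging call should produce, inferred from the keys iterated above; the exact serialization comes from the `schema_utils` helpers, so treat it as illustrative:

```
# Illustrative only: the payload shape implied by the keys that
# get_classification_metrics reads ("confusionMatrix", "annotationSpecs",
# "rows", "confidenceMetrics").
expected_metadata = {
    "confusionMatrix": {
        "annotationSpecs": [{"displayName": "cat"}, {"displayName": "dog"}],
        "rows": [[9, 1], [1, 9]],
    },
    "confidenceMetrics": [
        {"confidenceThreshold": 0.9, "falsePositiveRate": 0.1, "recall": 0.1},
        {"confidenceThreshold": 0.5, "falsePositiveRate": 0.5, "recall": 0.7},
        {"confidenceThreshold": 0.1, "falsePositiveRate": 0.9, "recall": 0.9},
    ],
}

# Partially specified or mismatched inputs are rejected before any artifact
# is created; for example, each of these raises ValueError:
#   run.log_classification_metrics(labels=["cat", "dog"])  # matrix missing
#   run.log_classification_metrics(fpr=[0.1], tpr=[0.1, 0.7], threshold=[0.9, 0.5])
```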
59 changes: 57 additions & 2 deletions google/cloud/aiplatform/metadata/metadata.py
@@ -15,8 +15,7 @@
# limitations under the License.
#


from typing import Dict, Union, Optional, Any
from typing import Dict, Union, Optional, Any, List

from google.api_core import exceptions
from google.auth import credentials as auth_credentials
@@ -371,6 +370,62 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
# query the latest metrics artifact resource before logging.
self._experiment_run.log_metrics(metrics=metrics)

def log_classification_metrics(
self,
*,
labels: Optional[List[str]] = None,
matrix: Optional[List[List[int]]] = None,
fpr: Optional[List[float]] = None,
tpr: Optional[List[float]] = None,
threshold: Optional[List[float]] = None,
display_name: Optional[str] = None,
):
"""Create an artifact for classification metrics and log to ExperimentRun. Currently support confusion matrix and ROC curve.
```
my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
my_run.log_classification_metrics(
display_name='my-classification-metrics',
labels=['cat', 'dog'],
matrix=[[9, 1], [1, 9]],
fpr=[0.1, 0.5, 0.9],
tpr=[0.1, 0.7, 0.9],
threshold=[0.9, 0.5, 0.1],
)
```
Args:
labels (List[str]):
Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
matrix (List[List[int]]):
Optional. Values for the confusion matrix. Must be set if 'labels' is set.
fpr (List[float]):
Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
tpr (List[float]):
Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
threshold (List[float]):
Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
display_name (str):
Optional. The user-defined name for the classification metric artifact.
Raises:
ValueError: if 'labels' and 'matrix' are not set together,
if 'labels' and 'matrix' do not have the same length,
if 'fpr', 'tpr' and 'threshold' are not set together,
or if 'fpr', 'tpr' and 'threshold' do not have the same length.
"""

self._validate_experiment_and_run(method_name="log_classification_metrics")
# query the latest metrics artifact resource before logging.
self._experiment_run.log_classification_metrics(
display_name=display_name,
labels=labels,
matrix=matrix,
fpr=fpr,
tpr=tpr,
threshold=threshold,
)

def _validate_experiment_and_run(self, method_name: str):
"""Validates Experiment and Run are set and raises informative error message.
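The tracker-level wrapper mirrors the `ExperimentRun` method but calls `_validate_experiment_and_run` first, so the module-level `aiplatform.log_classification_metrics` alias only works once an experiment and an active run exist. A small sketch of that precondition follows; the exception type and message come from `_validate_experiment_and_run`, which is not shown in this diff, so the `ValueError` below is an assumption:

```
from google.cloud import aiplatform

# Assumption: with no experiment/run active, the wrapper fails fast in
# _validate_experiment_and_run rather than reaching an ExperimentRun.
aiplatform.init(project="my-project", location="us-central1")
try:
    aiplatform.log_classification_metrics(
        labels=["cat", "dog"],
        matrix=[[9, 1], [1, 9]],
    )
except ValueError as err:  # assumed exception type, see note above
    print(f"expected failure before any run is active: {err}")
```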
65 changes: 61 additions & 4 deletions google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
@@ -15,7 +15,7 @@
# limitations under the License.

import copy
from typing import Optional, Dict
from typing import Optional, Dict, List

from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata.schema import base_artifact
@@ -24,6 +24,12 @@
# The artifact property key for the resource_name
_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME = "resourceName"

_CLASSIFICATION_METRICS_AGGREGATION_TYPE = [
"AGGREGATION_TYPE_UNSPECIFIED",
"MACRO_AVERAGE",
"MICRO_AVERAGE",
]


class VertexDataset(base_artifact.BaseArtifactSchema):
"""An artifact representing a Vertex Dataset."""
@@ -278,9 +284,17 @@ class ClassificationMetrics(base_artifact.BaseArtifactSchema):
def __init__(
self,
*,
aggregation_type: Optional[str] = None,
aggregation_threshold: Optional[float] = None,
recall: Optional[float] = None,
precision: Optional[float] = None,
f1_score: Optional[float] = None,
accuracy: Optional[float] = None,
au_prc: Optional[float] = None,
au_roc: Optional[float] = None,
log_loss: Optional[float] = None,
confusion_matrix: Optional[utils.ConfusionMatrix] = None,
confidence_metrics: Optional[List[utils.ConfidenceMetric]] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
@@ -290,6 +304,22 @@ def __init__(
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
aggregation_type (str):
Optional. The way to generate the aggregated metrics. Choose from the following options:
"AGGREGATION_TYPE_UNSPECIFIED": Indicating unset, used for per-class sliced metrics
"MACRO_AVERAGE": The unweighted average, default behavior
"MICRO_AVERAGE": The weighted average
aggregation_threshold (float):
Optional. The threshold used to generate aggregated metrics, default 0 for multi-class classification, 0.5 for binary classification.
recall (float):
Optional. Recall (True Positive Rate) for the given confidence threshold.
precision (float):
Optional. Precision for the given confidence threshold.
f1_score (float):
Optional. The harmonic mean of recall and precision.
accuracy (float):
Optional. Accuracy is the fraction of predictions given the correct label.
For multiclass this is a micro-average metric.
au_prc (float):
Optional. The Area Under Precision-Recall Curve metric.
Micro-averaged for the overall evaluation.
@@ -298,6 +328,10 @@
Micro-averaged for the overall evaluation.
log_loss (float):
Optional. The Log Loss metric.
confusion_matrix (utils.ConfusionMatrix):
Optional. Aggregated confusion matrix.
confidence_metrics (List[utils.ConfidenceMetric]):
Optional. List of metrics for different confidence thresholds.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
@@ -323,12 +357,35 @@
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if au_prc:
if aggregation_type:
if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
# TODO: add a negative test case for this
raise ValueError(
"aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
)
extended_metadata["aggregationType"] = aggregation_type
if aggregation_threshold is not None:
extended_metadata["aggregationThreshold"] = aggregation_threshold
if recall is not None:
extended_metadata["recall"] = recall
if precision is not None:
extended_metadata["precision"] = precision
if f1_score is not None:
extended_metadata["f1Score"] = f1_score
if accuracy is not None:
extended_metadata["accuracy"] = accuracy
if au_prc is not None:
extended_metadata["auPrc"] = au_prc
if au_roc:
if au_roc is not None:
extended_metadata["auRoc"] = au_roc
if log_loss:
if log_loss is not None:
extended_metadata["logLoss"] = log_loss
if confusion_matrix:
extended_metadata["confusionMatrix"] = confusion_matrix.to_dict()
if confidence_metrics:
extended_metadata["confidenceMetrics"] = [
confidence_metric.to_dict() for confidence_metric in confidence_metrics
]

super(ClassificationMetrics, self).__init__(
uri=uri,
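For code working at the artifact-schema level rather than through `ExperimentRun`, the extended class can also be instantiated directly with the new fields. Here is a hedged sketch using only parameters documented in this diff; the values are illustrative, and `.create()` assumes `aiplatform.init(...)` has already been called against a project with a metadata store:

```
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.schema import utils as schema_utils
from google.cloud.aiplatform.metadata.schema.google import (
    artifact_schema as google_artifact_schema,
)

aiplatform.init(project="my-project", location="us-central1")

metrics_schema = google_artifact_schema.ClassificationMetrics(
    display_name="eval-classification-metrics",
    aggregation_type="MACRO_AVERAGE",  # must be one of the three allowed strings
    aggregation_threshold=0.5,
    recall=0.85,
    precision=0.90,
    f1_score=0.87,
    accuracy=0.88,
    au_prc=0.92,
    au_roc=0.95,
    log_loss=0.30,
    confusion_matrix=schema_utils.ConfusionMatrix(
        annotation_specs=[
            schema_utils.AnnotationSpec(display_name="cat"),
            schema_utils.AnnotationSpec(display_name="dog"),
        ],
        matrix=[[9, 1], [1, 9]],
    ),
    confidence_metrics=[
        schema_utils.ConfidenceMetric(
            confidence_threshold=0.5,
            false_positive_rate=0.1,
            recall=0.8,
        )
    ],
)

# Registers the artifact in Vertex ML Metadata and returns the created Artifact.
classification_metrics_artifact = metrics_schema.create()
print(classification_metrics_artifact.resource_name)
```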

