Skip to content

Commit

Permalink
Backend support for persisting profiling/system metrics to generic_metrics (#8884)
Browse files Browse the repository at this point in the history
generic_metrics:
- DB schema changes
- Changes to backend ReportTrialMetrics APIs
  • Loading branch information
azhou-determined committed Mar 26, 2024
1 parent f08b406 commit 9f681af
Show file tree
Hide file tree
Showing 22 changed files with 407 additions and 63 deletions.
4 changes: 3 additions & 1 deletion master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ func TestCheckpointsOnArchivedSteps(t *testing.T) {
})
require.NoError(t, err)

step := int32(i)

for _, group := range []string{
model.ValidationMetricGroup.ToString(),
model.TrainingMetricGroup.ToString(),
Expand All @@ -147,7 +149,7 @@ func TestCheckpointsOnArchivedSteps(t *testing.T) {
Metrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: int32(trialRunID),
StepsCompleted: int32(i),
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: expectedMetrics,
},
Expand Down
22 changes: 15 additions & 7 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,15 @@ func createTestTrialWithMetrics(
},
}

step := int32(i)

group := model.MetricGroup("mygroup")
_, err := api.ReportTrialMetrics(ctx,
&apiv1.ReportTrialMetricsRequest{
Metrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: int32(i),
StepsCompleted: &step,
Metrics: trainMetrics,
},
Group: group.ToString(),
Expand All @@ -132,7 +134,7 @@ func createTestTrialWithMetrics(
TrainingMetrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: int32(i),
StepsCompleted: &step,
Metrics: trainMetrics,
},
})
Expand Down Expand Up @@ -171,7 +173,7 @@ func createTestTrialWithMetrics(
ValidationMetrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: int32(i),
StepsCompleted: &step,
Metrics: valMetrics,
},
})
Expand Down Expand Up @@ -440,11 +442,12 @@ func TestNonNumericEpochMetric(t *testing.T) {
require.NoError(t, err)

trial, _ := createTestTrial(t, api, curUser)
step := int32(1)
_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
ValidationMetrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: 1,
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: expectedMetrics,
},
Expand All @@ -466,12 +469,14 @@ func TestTrialsNonNumericMetrics(t *testing.T) {
expectedMetrics, err := structpb.NewStruct(expectedMetricsMap)
require.NoError(t, err)

step := int32(1)

trial, _ := createTestTrial(t, api, curUser)
_, err = api.ReportTrialMetrics(ctx, &apiv1.ReportTrialMetricsRequest{
Metrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: 1,
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: expectedMetrics,
},
Expand Down Expand Up @@ -628,12 +633,14 @@ func TestUnusualMetricNames(t *testing.T) {
expectedMetrics, err := structpb.NewStruct(expectedMetricsMap)
require.NoError(t, err)

step := int32(1)

trial, _ := createTestTrial(t, api, curUser)
_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
ValidationMetrics: &trialv1.TrialMetrics{
TrialId: int32(trial.ID),
TrialRunId: 0,
StepsCompleted: 1,
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: expectedMetrics,
},
Expand Down Expand Up @@ -1215,6 +1222,7 @@ func createTestTrialInferenceMetrics(ctx context.Context, t *testing.T, api *api
require.NoError(t, json.Unmarshal([]byte(
`{"inference": [{"a":1}, {"b":2}]}`,
), &trialMetrics))
step := int32(0)
for mType, metricsList := range trialMetrics {
for _, m := range metricsList {
metrics, err := structpb.NewStruct(m)
Expand All @@ -1223,7 +1231,7 @@ func createTestTrialInferenceMetrics(ctx context.Context, t *testing.T, api *api
&trialv1.TrialMetrics{
TrialId: id,
TrialRunId: int32(0),
StepsCompleted: int32(0),
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: metrics,
},
Expand Down
4 changes: 2 additions & 2 deletions master/internal/db/postgres_experiments_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestCheckpointMetadata(t *testing.T) {
if tt.hasValidation {
m = &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
StepsCompleted: stepsCompleted,
StepsCompleted: &stepsCompleted,
Metrics: &commonv1.Metrics{
AvgMetrics: &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down Expand Up @@ -701,7 +701,7 @@ func TestDeleteExperiments(t *testing.T) {
createMetric := func(sc int32, mv float64, trID int) *trialv1.TrialMetrics {
m := &trialv1.TrialMetrics{
TrialId: int32(trID),
StepsCompleted: sc,
StepsCompleted: &sc,
Metrics: &commonv1.Metrics{
AvgMetrics: &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down
3 changes: 2 additions & 1 deletion master/internal/db/postgres_model_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ func TestModels(t *testing.T) {
// Which maybe has some metrics.
var m *trialv1.TrialMetrics
const metricValue = 1.0
step := int32(stepsCompleted)
if tt.hasValidation {
m = &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
StepsCompleted: stepsCompleted,
StepsCompleted: &step,
Metrics: &commonv1.Metrics{
AvgMetrics: &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down
2 changes: 1 addition & 1 deletion master/internal/db/postgres_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ func AddTrialValidationMetrics(
trialMetrics := trialv1.TrialMetrics{
TrialId: int32(tr.ID),
TrialRunId: int32(0),
StepsCompleted: stepsCompleted,
StepsCompleted: &stepsCompleted,
Metrics: &commonv1.Metrics{
AvgMetrics: &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down
26 changes: 22 additions & 4 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"regexp"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -367,6 +368,18 @@ func (db *PgDB) updateTotalBatches(ctx context.Context, tx *sqlx.Tx, trialID int
return nil
}

// _addTrialProfilingMetricsTx stores one report of profiling metrics for a
// trial inside the supplied transaction.
//
// It first calls checkTrialRunID with the trial and run IDs carried by m and
// returns that error unchanged if the check fails. Otherwise the averaged
// metrics in m are converted to a model.JSONObj and handed to addRawMetrics
// together with the report time, run ID, trial ID and metric group. Only the
// error from addRawMetrics is returned; its other result is discarded.
//
// NOTE(review): m.Metrics is dereferenced directly, so callers are presumably
// expected to pass a report with Metrics populated — confirm at the call site.
func (db *PgDB) _addTrialProfilingMetricsTx(
	ctx context.Context, tx *sqlx.Tx, m *trialv1.TrialMetrics, mGroup model.MetricGroup,
) error {
	err := checkTrialRunID(ctx, tx, m.TrialId, m.TrialRunId)
	if err != nil {
		return err
	}

	body := model.JSONObj(m.Metrics.AvgMetrics.AsMap())
	reportedAt := tryAsTime(m.ReportTime)

	// nil is passed in the slot just before the metric group; presumably this
	// is the step count, which is not supplied for this kind of report —
	// TODO confirm against addRawMetrics.
	if _, err = db.addRawMetrics(
		ctx, tx, &body, reportedAt, m.TrialRunId, m.TrialId, nil, mGroup,
	); err != nil {
		return err
	}
	return nil
}

func (db *PgDB) _addTrialMetricsTx(
ctx context.Context, tx *sqlx.Tx, m *trialv1.TrialMetrics, mGroup model.MetricGroup,
) (rollbacks int, err error) {
Expand All @@ -377,7 +390,7 @@ func (db *PgDB) _addTrialMetricsTx(
return rollbacks, err
}

if rollbacks, err = rollbackMetrics(ctx, tx, m.TrialRunId, m.TrialId, m.StepsCompleted,
if rollbacks, err = rollbackMetrics(ctx, tx, m.TrialRunId, m.TrialId, m.GetStepsCompleted(),
mGroup); err != nil {
return rollbacks, err
}
Expand All @@ -390,7 +403,7 @@ func (db *PgDB) _addTrialMetricsTx(
}

metricRowID, addedMetrics, err := db.addMetricsWithMerge(ctx, tx,
mBody, m.TrialRunId, m.TrialId, m.StepsCompleted, mGroup)
mBody, tryAsTime(m.ReportTime), m.TrialRunId, m.TrialId, m.StepsCompleted, mGroup)
if err != nil {
return rollbacks, err
}
Expand Down Expand Up @@ -482,7 +495,7 @@ WHERE id = $1;
if err := setTrialBestValidation(
tx, int(m.TrialId),
int(m.TrialRunId),
int(m.StepsCompleted)); err != nil {
int(m.GetStepsCompleted())); err != nil {
return rollbacks, errors.Wrap(err, "updating trial best validation")
}
}
Expand All @@ -500,7 +513,12 @@ func (db *PgDB) addTrialMetrics(
}
return rollbacks, db.withTransaction(fmt.Sprintf("add trial metrics %s", mGroup),
func(tx *sqlx.Tx) error {
rollbacks, err = db._addTrialMetricsTx(ctx, tx, m, mGroup)
switch {
case slices.Contains(model.ProfilingMetricGroups, mGroup):
err = db._addTrialProfilingMetricsTx(ctx, tx, m, mGroup)
default:
rollbacks, err = db._addTrialMetricsTx(ctx, tx, m, mGroup)
}
return err
})
}
Expand Down
Loading

0 comments on commit 9f681af

Please sign in to comment.