Skip to content

Commit

Permalink
fix: Revert to get_checkpoints.sql call to enable NaN & Infinity valu…
Browse files Browse the repository at this point in the history
…es in searcher metric (#9440)
  • Loading branch information
AmanuelAaron authored Jun 4, 2024
1 parent d50433d commit 13a5142
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
15 changes: 9 additions & 6 deletions master/internal/api_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ func (a *apiServer) GetCheckpoint(
}

resp := &apiv1.GetCheckpointResponse{}
resp.Checkpoint = &checkpointv1.Checkpoint{}

ckpt, err := internaldb.GetCheckpoint(ctx, req.CheckpointUuid)
if err != nil {
return resp, errors.Wrapf(err, "error fetching checkpoint %s from database", req.CheckpointUuid)
// We don't use Bun here as Bun's marshaling/unmarshaling does not account for "NaN",
// "Infinity", and "-Infinity"
if err := a.m.db.QueryProto(
"get_checkpoint", resp.Checkpoint, req.CheckpointUuid); err != nil {
return resp,
errors.Wrapf(err, "error fetching checkpoint %s from database", req.CheckpointUuid)
}
resp.Checkpoint = ckpt
return resp, nil
}

Expand Down Expand Up @@ -476,8 +479,8 @@ func (a *apiServer) PostCheckpointMetadata(
return nil, err
}

currCheckpoint, err := internaldb.GetCheckpoint(ctx, req.Checkpoint.Uuid)
if err != nil {
currCheckpoint := &checkpointv1.Checkpoint{}
if err := a.m.db.QueryProto("get_checkpoint", currCheckpoint, req.Checkpoint.Uuid); err != nil {
return nil,
errors.Wrapf(err, "error fetching checkpoint %s from database", req.Checkpoint.Uuid)
}
Expand Down
106 changes: 106 additions & 0 deletions master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"testing"
"time"

Expand Down Expand Up @@ -328,6 +329,111 @@ func TestCheckpointReturned(t *testing.T) {
})
}

// TestGetCheckpointNaNInfinityValues verifies that GetCheckpoint returns
// non-finite metric values unchanged. For each of NaN, Infinity and -Infinity
// it reports a checkpoint whose searcher metric is the float value and whose
// validation "loss" metric is the matching string form, then fetches the
// checkpoint by UUID and checks that both values come back as reported.
func TestGetCheckpointNaNInfinityValues(t *testing.T) {
	tests := map[string]struct {
		metric      string
		metricValue float64
	}{
		"NaNCase": {
			metric:      "NaN",
			metricValue: math.NaN(),
		},
		"InfinityCase": {
			metric:      "Infinity",
			metricValue: math.Inf(1),
		},
		"-InfinityCase": {
			metric:      "-Infinity",
			metricValue: math.Inf(-1),
		},
	}
	for testCase, testVars := range tests {
		t.Run(testCase, func(t *testing.T) {
			api, curUser, ctx := setupAPITest(t, nil)
			trial, task := createTestTrial(t, api, curUser)

			checkpointStorage, err := structpb.NewStruct(map[string]any{
				"type":        "shared_fs",
				"host_path":   uuid.New().String(),
				"propagation": "private",
			})
			require.NoError(t, err)

			// The first call passes no checkpoint storage, and no storage ID is
			// expected in the response.
			reportResponse, err := api.RunPrepareForReporting(ctx, &apiv1.RunPrepareForReportingRequest{
				RunId: int32(trial.ID),
			})
			require.NoError(t, err)
			require.Nil(t, reportResponse.StorageId)

			// The second call supplies checkpoint storage, so a storage ID is expected.
			reportResponse, err = api.RunPrepareForReporting(ctx, &apiv1.RunPrepareForReportingRequest{
				RunId:             int32(trial.ID),
				CheckpointStorage: checkpointStorage,
			})
			require.NoError(t, err)
			require.NotNil(t, reportResponse.StorageId)

			checkpointMeta, err := structpb.NewStruct(map[string]any{
				"steps_completed": 1,
			})
			require.NoError(t, err)

			checkpointID := uuid.New().String()
			int32TrialID := int32(trial.ID)
			int32ExperimentID := int32(trial.ExperimentID)
			// Take the address of a per-subtest copy rather than of the loop variable.
			searcherMetric := testVars.metricValue

			checkpoint := &checkpointv1.Checkpoint{
				TaskId:       string(task.TaskID),
				AllocationId: nil,
				Uuid:         checkpointID,
				ReportTime:   timestamppb.New(time.Now().UTC().Truncate(time.Millisecond)),
				Resources:    map[string]int64{"x": 128, "y/": 0},
				Metadata:     checkpointMeta,
				State:        checkpointv1.State_STATE_COMPLETED,
				StorageId:    reportResponse.StorageId,
				Training: &checkpointv1.CheckpointTrainingMetadata{
					TrialId:         &int32TrialID,
					ExperimentId:    &int32ExperimentID,
					TrainingMetrics: &commonv1.Metrics{},
					ValidationMetrics: &commonv1.Metrics{
						AvgMetrics: &structpb.Struct{
							Fields: map[string]*structpb.Value{
								"loss": structpb.NewStringValue(testVars.metric),
							},
						},
					},
					SearcherMetric: &searcherMetric,
				},
			}
			_, err = api.ReportCheckpoint(ctx, &apiv1.ReportCheckpointRequest{
				Checkpoint: checkpoint,
			})
			require.NoError(t, err)

			_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
				ValidationMetrics: &trialv1.TrialMetrics{
					TrialId:        int32TrialID,
					TrialRunId:     0,
					StepsCompleted: ptrs.Ptr(int32(1)),
					ReportTime:     timestamppb.New(time.Now().UTC().Truncate(time.Millisecond)),
					Metrics:        checkpoint.Training.ValidationMetrics,
				},
			})
			require.NoError(t, err)

			resp, err := api.GetCheckpoint(ctx, &apiv1.GetCheckpointRequest{
				CheckpointUuid: checkpointID,
			})
			require.NoError(t, err)

			// Guard every pointer on the path so a regression shows up as a test
			// failure instead of a nil-pointer panic.
			require.NotNil(t, resp.Checkpoint)
			training := resp.Checkpoint.Training
			require.NotNil(t, training)
			require.NotNil(t, training.ValidationMetrics)
			require.NotNil(t, training.ValidationMetrics.AvgMetrics)
			require.NotNil(t, training.SearcherMetric)

			require.Equal(t, testVars.metric,
				training.ValidationMetrics.AvgMetrics.Fields["loss"].GetStringValue())
			if math.IsNaN(testVars.metricValue) {
				// NaN never compares equal to itself, so check it with math.IsNaN.
				require.True(t, math.IsNaN(*training.SearcherMetric))
			} else {
				require.Equal(t, testVars.metricValue, *training.SearcherMetric)
			}
		})
	}
}

func TestCheckpointRemoveFilesPrefixAndEmpty(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
_, err := api.CheckpointsRemoveFiles(ctx, &apiv1.CheckpointsRemoveFilesRequest{
Expand Down
8 changes: 5 additions & 3 deletions master/internal/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,12 @@ func (a *apiServer) PostModelVersion(
}

// make sure the checkpoint exists
c, getCheckpointErr := db.GetCheckpoint(ctx, req.CheckpointUuid)
if getCheckpointErr == db.ErrNotFound {
c := &checkpointv1.Checkpoint{}

switch getCheckpointErr := a.m.db.QueryProto("get_checkpoint", c, req.CheckpointUuid); {
case getCheckpointErr == db.ErrNotFound:
return nil, api.NotFoundErrs("checkpoint", req.CheckpointUuid, true)
} else if getCheckpointErr != nil {
case getCheckpointErr != nil:
return nil, getCheckpointErr
}

Expand Down
1 change: 1 addition & 0 deletions master/internal/db/postgres_experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,7 @@ WHERE id = $1`, id)

// GetCheckpoint gets checkpointv1.Checkpoint from the database by UUID.
// Can be moved to master/internal/checkpoints once db/postgres_model_intg_test is bunified.
// WARNING: Function does not account for "NaN", "Infinity", or "-Infinity" due to Bun unmarshaling.
func GetCheckpoint(ctx context.Context, checkpointUUID string) (*checkpointv1.Checkpoint, error) {
var retCkpt1 checkpointv1.Checkpoint
err := Bun().NewSelect().
Expand Down
3 changes: 3 additions & 0 deletions master/static/srv/get_checkpoint.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Selects every column of the proto_checkpoints_view row whose uuid equals
-- the first query parameter ($1).
SELECT *
FROM proto_checkpoints_view c
WHERE c.uuid = $1

0 comments on commit 13a5142

Please sign in to comment.