Skip to content

Commit

Permalink
Add ability to access GroupHandle and FieldHandle
Browse files Browse the repository at this point in the history
Signed-off-by: Rohit Arora <[email protected]>
  • Loading branch information
rohit-arora-dev committed Mar 15, 2024
1 parent 150a87b commit b23a382
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
30 changes: 24 additions & 6 deletions pkg/dcgm/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ type FieldMeta struct {

type FieldHandle struct{ handle C.dcgmFieldGrp_t }

func (f *FieldHandle) SetHandle(val uintptr) {
f.handle = C.dcgmGpuGrp_t(val)
}

func (f *FieldHandle) GetHandle() uintptr {
return uintptr(f.handle)
}

func FieldGroupCreate(fieldsGroupName string, fields []Short) (fieldsId FieldHandle, err error) {
var fieldsGroup C.dcgmFieldGrp_t
cfields := *(*[]C.ushort)(unsafe.Pointer(&fields))
Expand Down Expand Up @@ -66,7 +74,8 @@ func WatchFields(gpuId uint, fieldsGroup FieldHandle, groupName string) (groupId
return
}

result := C.dcgmWatchFields(handle.handle, group.handle, fieldsGroup.handle, C.longlong(defaultUpdateFreq), C.double(defaultMaxKeepAge), C.int(defaultMaxKeepSamples))
result := C.dcgmWatchFields(handle.handle, group.handle, fieldsGroup.handle, C.longlong(defaultUpdateFreq),
C.double(defaultMaxKeepAge), C.int(defaultMaxKeepSamples))
if err = errorString(result); err != nil {
return groupId, fmt.Errorf("Error watching fields: %s", err)
}
Expand All @@ -75,7 +84,9 @@ func WatchFields(gpuId uint, fieldsGroup FieldHandle, groupName string) (groupId
return group, nil
}

func WatchFieldsWithGroupEx(fieldsGroup FieldHandle, group GroupHandle, updateFreq int64, maxKeepAge float64, maxKeepSamples int32) error {
func WatchFieldsWithGroupEx(
fieldsGroup FieldHandle, group GroupHandle, updateFreq int64, maxKeepAge float64, maxKeepSamples int32,
) error {
result := C.dcgmWatchFields(handle.handle, group.handle, fieldsGroup.handle,
C.longlong(updateFreq), C.double(maxKeepAge), C.int(maxKeepSamples))

Expand Down Expand Up @@ -118,7 +129,8 @@ func EntityGetLatestValues(entityGroup Field_Entity_Group, entityId uint, fields
values := make([]C.dcgmFieldValue_v1, len(fields))
cfields := (*C.ushort)(unsafe.Pointer(&fields[0]))

result := C.dcgmEntityGetLatestValues(handle.handle, C.dcgm_field_entity_group_t(entityGroup), C.int(entityId), cfields, C.uint(len(fields)), &values[0])
result := C.dcgmEntityGetLatestValues(handle.handle, C.dcgm_field_entity_group_t(entityGroup), C.int(entityId),
cfields, C.uint(len(fields)), &values[0])
if result != C.DCGM_ST_OK {
return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result}
}
Expand All @@ -132,10 +144,14 @@ func EntitiesGetLatestValues(entities []GroupEntityPair, fields []Short, flags u
cEntities := make([]C.dcgmGroupEntityPair_t, len(entities))
cPtrEntities := *(*[]C.dcgmGroupEntityPair_t)(unsafe.Pointer(&cEntities))
for i, entity := range entities {
cEntities[i] = C.dcgmGroupEntityPair_t{C.dcgm_field_entity_group_t(entity.EntityGroupId), C.dcgm_field_eid_t(entity.EntityId)}
cEntities[i] = C.dcgmGroupEntityPair_t{
C.dcgm_field_entity_group_t(entity.EntityGroupId),
C.dcgm_field_eid_t(entity.EntityId),
}
}

result := C.dcgmEntitiesGetLatestValues(handle.handle, &cPtrEntities[0], C.uint(len(entities)), cfields, C.uint(len(fields)), C.uint(flags), &values[0])
result := C.dcgmEntitiesGetLatestValues(handle.handle, &cPtrEntities[0], C.uint(len(entities)), cfields,
C.uint(len(fields)), C.uint(flags), &values[0])
if err := errorString(result); err != nil {
return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result}
}
Expand Down Expand Up @@ -215,7 +231,9 @@ func toFieldValue_v2(cfields []C.dcgmFieldValue_v2) []FieldValue_v2 {
return fields
}

func dcgmFieldValue_v1ToFieldValue_v2(fieldEntityGroup Field_Entity_Group, entityId uint, cfields []C.dcgmFieldValue_v1) []FieldValue_v2 {
func dcgmFieldValue_v1ToFieldValue_v2(
fieldEntityGroup Field_Entity_Group, entityId uint, cfields []C.dcgmFieldValue_v1,
) []FieldValue_v2 {
fields := make([]FieldValue_v2, len(cfields))
for i, f := range cfields {
fields[i] = FieldValue_v2{
Expand Down
19 changes: 19 additions & 0 deletions pkg/dcgm/fields_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dcgm

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFieldHandle(t *testing.T) {
fh := FieldHandle{}
assert.Equal(t, uintptr(0), fh.GetHandle(), "value mismatch")

inputs := []uintptr{1000, 0, 1, 10, 11, 50, 100, 1939902, 9992932938239, 999999999999999999}

for _, input := range inputs {
fh.SetHandle(input)
assert.Equal(t, input, fh.GetHandle(), "values mismatch")
}
}
11 changes: 10 additions & 1 deletion pkg/dcgm/gpu_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ const (

type GroupHandle struct{ handle C.dcgmGpuGrp_t }

func (g *GroupHandle) SetHandle(val uintptr) {
g.handle = C.dcgmGpuGrp_t(val)
}

func (g *GroupHandle) GetHandle() uintptr {
return uintptr(g.handle)
}

func GroupAllGPUs() GroupHandle {
return GroupHandle{C.DCGM_GROUP_ALL_GPUS}
}
Expand Down Expand Up @@ -67,7 +75,8 @@ func AddLinkEntityToGroup(groupId GroupHandle, index uint, parentId uint) (err e
}

func AddEntityToGroup(groupId GroupHandle, entityGroupId Field_Entity_Group, entityId uint) (err error) {
result := C.dcgmGroupAddEntity(handle.handle, groupId.handle, C.dcgm_field_entity_group_t(entityGroupId), C.uint(entityId))
result := C.dcgmGroupAddEntity(handle.handle, groupId.handle, C.dcgm_field_entity_group_t(entityGroupId),
C.uint(entityId))
if err = errorString(result); err != nil {
return fmt.Errorf("Error adding entity group type %v, entity %v to group: %s", entityGroupId, entityId, err)
}
Expand Down
19 changes: 19 additions & 0 deletions pkg/dcgm/gpu_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dcgm

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGroupHandle(t *testing.T) {
gh := GroupHandle{}
assert.Equal(t, uintptr(0), gh.GetHandle(), "value mismatch")

inputs := []uintptr{1000, 0, 1, 10, 11, 50, 100, 1939902, 9992932938239, 999999999999999999}

for _, input := range inputs {
gh.SetHandle(input)
assert.Equal(t, input, gh.GetHandle(), "values mismatch")
}
}

0 comments on commit b23a382

Please sign in to comment.