Skip to content

Commit

Permalink
Merge pull request #54 from dran-dev/main
Browse files Browse the repository at this point in the history
Add new API ListenForPolicyViolations to replace Policy
  • Loading branch information
nvvfedorov authored Jan 11, 2024
2 parents 3c233ee + 1c239ce commit 26fbf85
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 88 deletions.
13 changes: 10 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
module github.com/NVIDIA/go-dcgm

go 1.16
go 1.21

require (
github.com/Masterminds/semver v1.5.0
github.com/bits-and-blooms/bitset v1.2.1
github.com/gorilla/mux v1.8.0
github.com/bits-and-blooms/bitset v1.13.0
github.com/gorilla/mux v1.8.1
github.com/stretchr/testify v1.8.4
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3Q
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/bits-and-blooms/bitset v1.2.1 h1:M+/hrU9xlMp7t4TyTDQW97d3tRPVuKFC6zBEK16QnXY=
github.com/bits-and-blooms/bitset v1.2.1/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
Expand Down
8 changes: 5 additions & 3 deletions pkg/dcgm/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dcgm

import (
"context"
"fmt"
"os"
"sync"
Expand Down Expand Up @@ -103,9 +104,10 @@ func HealthCheckByGpuId(gpuId uint) (DeviceHealth, error) {
return healthCheckByGpuId(gpuId)
}

// Policy sets GPU usage and error policies and notifies in case of any violations via callback functions
func Policy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation, error) {
return registerPolicy(gpuId, typ...)
// ListenForPolicyViolations sets GPU usage and error policies and notifies in case of any violations
func ListenForPolicyViolations(ctx context.Context, typ ...policyCondition) (<-chan PolicyViolation, error) {
groupId := GroupAllGPUs()
return registerPolicy(ctx, groupId, typ...)
}

// Introspect returns DCGM hostengine memory and CPU usage
Expand Down
3 changes: 3 additions & 0 deletions pkg/dcgm/bcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,8 @@ func (p *publisher) broadcast() {
}

// closePublisher removes every subscriber currently reported by
// subscriberList and then sends true on the publisher's close channel.
// NOTE(review): the receiver of p.close is not visible here; presumably the
// broadcast loop stops on it — confirm. The send blocks if p.close is
// unbuffered and nothing is receiving.
func (p *publisher) closePublisher() {
	// Take the subscriber list once, then detach each entry before signalling close.
	subscribers := p.subscriberList()
	for i := range subscribers {
		p.remove(subscribers[i])
	}
	p.close <- true
}
63 changes: 32 additions & 31 deletions pkg/dcgm/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ extern int violationNotify(void* p);
*/
import "C"
import (
"context"
"encoding/binary"
"fmt"
"log"
"math/rand"
"sync"
"time"
"unsafe"
Expand Down Expand Up @@ -262,7 +262,7 @@ func setPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t, paramList
for _, key := range paramList {
conditionParam, exists := paramMap[policyIndex(key)]
if !exists {
return fmt.Errorf("Error: Invalid Policy condition, %v does not exist.\n", key)
return fmt.Errorf("Error: Invalid Policy condition, %v does not exist", key)
}
// set policy condition parameters
// set condition type (bool or longlong)
Expand All @@ -287,23 +287,11 @@ func setPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t, paramList
return
}

func registerPolicy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation, error) {
func registerPolicy(ctx context.Context, groupId GroupHandle, typ ...policyCondition) (<-chan PolicyViolation, error) {
// init policy globals for internal API
makePolicyChannels()
makePolicyParmsMap()

name := fmt.Sprintf("policy%d", rand.Uint64())
groupId, err := CreateGroup(name)
if err != nil {
return nil, err
}

if err = AddToGroup(groupId, gpuId); err != nil {
return nil, err
}

// make a list of all callback channels
var channels []chan PolicyViolation
// make a list of policy conditions for setting their parameters
var paramKeys []policyIndex
// get all conditions to be set in setPolicy()
Expand All @@ -313,54 +301,67 @@ func registerPolicy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation,
case DbePolicy:
paramKeys = append(paramKeys, dbePolicyIndex)
condition |= C.DCGM_POLICY_COND_DBE
channels = append(channels, callbacks["dbe"])
case PCIePolicy:
paramKeys = append(paramKeys, pciePolicyIndex)
condition |= C.DCGM_POLICY_COND_PCI
channels = append(channels, callbacks["pcie"])
case MaxRtPgPolicy:
paramKeys = append(paramKeys, maxRtPgPolicyIndex)
condition |= C.DCGM_POLICY_COND_MAX_PAGES_RETIRED
channels = append(channels, callbacks["maxrtpg"])
case ThermalPolicy:
paramKeys = append(paramKeys, thermalPolicyIndex)
condition |= C.DCGM_POLICY_COND_THERMAL
channels = append(channels, callbacks["thermal"])
case PowerPolicy:
paramKeys = append(paramKeys, powerPolicyIndex)
condition |= C.DCGM_POLICY_COND_POWER
channels = append(channels, callbacks["power"])
case NvlinkPolicy:
paramKeys = append(paramKeys, nvlinkPolicyIndex)
condition |= C.DCGM_POLICY_COND_NVLINK
channels = append(channels, callbacks["nvlink"])
case XidPolicy:
paramKeys = append(paramKeys, xidPolicyIndex)
condition |= C.DCGM_POLICY_COND_XID
channels = append(channels, callbacks["xid"])
}
}

var err error
if err = setPolicy(groupId, condition, paramKeys); err != nil {
return nil, err
}

result := C.dcgmPolicyRegister(handle.handle, groupId.handle, C.dcgmPolicyCondition_t(condition), C.fpRecvUpdates(C.violationNotify), C.fpRecvUpdates(C.violationNotify))
var finishCallback unsafe.Pointer
result := C.dcgmPolicyRegister(handle.handle, groupId.handle, C.dcgmPolicyCondition_t(condition), C.fpRecvUpdates(C.violationNotify), C.fpRecvUpdates(finishCallback))

if err = errorString(result); err != nil {
return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result}
}
log.Println("Listening for violations...")

// merge
violation := make(chan PolicyViolation, len(channels))
violation := make(chan PolicyViolation, len(typ))
go func() {
for _, c := range channels {
val := <-c
violation <- val
defer func() {
log.Println("unregister policy violation...")
close(violation)
unregisterPolicy(groupId, condition)
}()
for {
select {
case dbe := <-callbacks["dbe"]:
violation <- dbe
case pcie := <-callbacks["pcie"]:
violation <- pcie
case maxrtpg := <-callbacks["maxrtpg"]:
violation <- maxrtpg
case thermal := <-callbacks["thermal"]:
violation <- thermal
case power := <-callbacks["power"]:
violation <- power
case nvlink := <-callbacks["nvlink"]:
violation <- nvlink
case xid := <-callbacks["xid"]:
violation <- xid
case <-ctx.Done():
return
}
}
DestroyGroup(groupId)
close(violation)
}()

return violation, err
Expand All @@ -370,7 +371,7 @@ func unregisterPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t) {
result := C.dcgmPolicyUnregister(handle.handle, groupId.handle, condition)

if err := errorString(result); err != nil {
fmt.Errorf("Error unregistering policy: %s", err)
log.Println(fmt.Errorf("error unregistering policy: %s", err))
}
}

Expand Down
Loading

0 comments on commit 26fbf85

Please sign in to comment.